1use std::cmp::Ordering;
5
6use arrow_array::BooleanArray;
7use arrow_buffer::NullBuffer;
8use arrow_ord::cmp;
9use arrow_ord::ord::make_comparator;
10use arrow_schema::SortOptions;
11use vortex_error::VortexResult;
12
13use crate::ArrayRef;
14use crate::Canonical;
15use crate::ExecutionCtx;
16use crate::IntoArray;
17use crate::arrays::ConstantArray;
18use crate::arrays::ConstantVTable;
19use crate::arrays::ScalarFnVTable;
20use crate::arrays::scalar_fn::ExactScalarFn;
21use crate::arrays::scalar_fn::ScalarFnArrayView;
22use crate::arrow::Datum;
23use crate::arrow::IntoArrowArray;
24use crate::arrow::from_arrow_array_with_len;
25use crate::dtype::DType;
26use crate::dtype::Nullability;
27use crate::kernel::ExecuteParentKernel;
28use crate::scalar::Scalar;
29use crate::scalar_fn::fns::binary::Binary;
30use crate::scalar_fn::fns::operators::CompareOperator;
31use crate::vtable::VTable;
32
33pub trait CompareKernel: VTable {
39 fn compare(
40 lhs: &Self::Array,
41 rhs: &ArrayRef,
42 operator: CompareOperator,
43 ctx: &mut ExecutionCtx,
44 ) -> VortexResult<Option<ArrayRef>>;
45}
46
/// Wraps an encoding's vtable `V` so that its [`CompareKernel`] implementation can be run
/// as an [`ExecuteParentKernel`] for `Binary` comparison expressions.
///
/// The struct itself places no bound on `V`. The `ExecuteParentKernel` impl requires
/// `V: CompareKernel`.
#[derive(Debug, Default)]
pub struct CompareExecuteAdaptor<V>(pub V);
54
55impl<V> ExecuteParentKernel<V> for CompareExecuteAdaptor<V>
56where
57 V: CompareKernel,
58{
59 type Parent = ExactScalarFn<Binary>;
60
61 fn execute_parent(
62 &self,
63 array: &V::Array,
64 parent: ScalarFnArrayView<'_, Binary>,
65 child_idx: usize,
66 ctx: &mut ExecutionCtx,
67 ) -> VortexResult<Option<ArrayRef>> {
68 let Ok(cmp_op) = CompareOperator::try_from(*parent.options) else {
70 return Ok(None);
71 };
72
73 let Some(scalar_fn_array) = parent.as_opt::<ScalarFnVTable>() else {
75 return Ok(None);
76 };
77 let children = scalar_fn_array.children();
78
79 let (cmp_op, other) = match child_idx {
82 0 => (cmp_op, &children[1]),
83 1 => (cmp_op.swap(), &children[0]),
84 _ => return Ok(None),
85 };
86
87 let len = array.len();
88 let nullable = array.dtype().is_nullable() || other.dtype().is_nullable();
89
90 if len == 0 {
92 return Ok(Some(
93 Canonical::empty(&DType::Bool(nullable.into())).into_array(),
94 ));
95 }
96
97 if other.as_constant().is_some_and(|s| s.is_null()) {
99 return Ok(Some(
100 ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), len).into_array(),
101 ));
102 }
103
104 V::compare(array, other, cmp_op, ctx)
105 }
106}
107
108pub(crate) fn execute_compare(
113 lhs: &ArrayRef,
114 rhs: &ArrayRef,
115 op: CompareOperator,
116) -> VortexResult<ArrayRef> {
117 let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
118
119 if lhs.is_empty() {
120 return Ok(Canonical::empty(&DType::Bool(nullable.into())).into_array());
121 }
122
123 let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
124 let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
125 if left_constant_null || right_constant_null {
126 return Ok(
127 ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), lhs.len()).into_array(),
128 );
129 }
130
131 if let (Some(lhs_const), Some(rhs_const)) = (
133 lhs.as_opt::<ConstantVTable>(),
134 rhs.as_opt::<ConstantVTable>(),
135 ) {
136 let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op);
137 return Ok(ConstantArray::new(result, lhs.len()).into_array());
138 }
139
140 arrow_compare_arrays(lhs, rhs, op)
141}
142
143fn arrow_compare_arrays(
145 left: &ArrayRef,
146 right: &ArrayRef,
147 operator: CompareOperator,
148) -> VortexResult<ArrayRef> {
149 assert_eq!(left.len(), right.len());
150
151 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
152
153 let array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
156 let rhs = right.to_array().into_arrow_preferred()?;
157 let lhs = left.to_array().into_arrow(rhs.data_type())?;
158
159 assert!(
160 lhs.data_type().equals_datatype(rhs.data_type()),
161 "lhs data_type: {}, rhs data_type: {}",
162 lhs.data_type(),
163 rhs.data_type()
164 );
165
166 compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)?
167 } else {
168 let lhs = Datum::try_new(left)?;
170 let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type())?;
171
172 match operator {
173 CompareOperator::Eq => cmp::eq(&lhs, &rhs)?,
174 CompareOperator::NotEq => cmp::neq(&lhs, &rhs)?,
175 CompareOperator::Gt => cmp::gt(&lhs, &rhs)?,
176 CompareOperator::Gte => cmp::gt_eq(&lhs, &rhs)?,
177 CompareOperator::Lt => cmp::lt(&lhs, &rhs)?,
178 CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?,
179 }
180 };
181 from_arrow_array_with_len(&array, left.len(), nullable)
182}
183
184pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> Scalar {
185 if lhs.is_null() | rhs.is_null() {
186 Scalar::null(DType::Bool(Nullability::Nullable))
187 } else {
188 let b = match operator {
189 CompareOperator::Eq => lhs == rhs,
190 CompareOperator::NotEq => lhs != rhs,
191 CompareOperator::Gt => lhs > rhs,
192 CompareOperator::Gte => lhs >= rhs,
193 CompareOperator::Lt => lhs < rhs,
194 CompareOperator::Lte => lhs <= rhs,
195 };
196
197 Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
198 }
199}
200
201pub fn compare_nested_arrow_arrays(
209 lhs: &dyn arrow_array::Array,
210 rhs: &dyn arrow_array::Array,
211 operator: CompareOperator,
212) -> VortexResult<BooleanArray> {
213 let compare_arrays_at = make_comparator(lhs, rhs, SortOptions::default())?;
214
215 let cmp_fn = match operator {
216 CompareOperator::Eq => Ordering::is_eq,
217 CompareOperator::NotEq => Ordering::is_ne,
218 CompareOperator::Gt => Ordering::is_gt,
219 CompareOperator::Gte => Ordering::is_ge,
220 CompareOperator::Lt => Ordering::is_lt,
221 CompareOperator::Lte => Ordering::is_le,
222 };
223
224 let values = (0..lhs.len())
225 .map(|i| cmp_fn(compare_arrays_at(i, i)))
226 .collect();
227 let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
228
229 Ok(BooleanArray::new(values, nulls))
230}
231
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::ListViewArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::binary::compare::ConstantArray;
    use crate::scalar_fn::fns::operators::Operator;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    // Exercises every comparison operator on nullable bool arrays.
    // Row 0 is null in both inputs, so it never appears among the matching indices,
    // even for `Eq` of an array with itself.
    #[test]
    fn test_bool_basic_comparisons() {
        use vortex_buffer::BitBuffer;

        let arr = BoolArray::new(
            BitBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = arr
            .clone()
            .into_array()
            .binary(arr.clone().into_array(), Operator::Eq)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]);

        let matches = arr
            .clone()
            .into_array()
            .binary(arr.clone().into_array(), Operator::NotEq)
            .unwrap()
            .to_bool();
        let empty: [u64; 0] = [];
        assert_eq!(to_int_indices(matches).unwrap(), empty);

        let other = BoolArray::new(
            BitBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = arr
            .clone()
            .into_array()
            .binary(other.clone().into_array(), Operator::Lte)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = arr
            .clone()
            .into_array()
            .binary(other.clone().into_array(), Operator::Lt)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);

        // With the operands reversed, `other >= arr` must agree with `arr <= other`.
        let matches = other
            .clone()
            .into_array()
            .binary(arr.clone().into_array(), Operator::Gte)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = other
            .into_array()
            .binary(arr.into_array(), Operator::Gt)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
    }

    // Two constant arrays are compared as scalars; the result keeps the input length.
    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        let result = left
            .into_array()
            .binary(right.into_array(), Operator::Gt)
            .unwrap();
        assert_eq!(result.len(), 10);
        let scalar = result.scalar_at(0).unwrap();
        assert_eq!(scalar.as_bool().value(), Some(false));
    }

    // Equal string/binary data must compare equal regardless of whether it is stored as
    // VarBin or VarBinView, in either operand order.
    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let res = left.binary(right, Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([true, true]);
        assert_arrays_eq!(res, expected);
    }

    // List-vs-list comparison (currently ignored; see the `ignore` reason).
    // The last rows [5, 6] vs [7, 8] differ, and [5, 6] < [7, 8].
    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_comparison() {
        let values1 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
        let offsets1 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list1 = ListArray::try_new(
            values1.into_array(),
            offsets1.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let values2 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 7, 8]);
        let offsets2 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list2 = ListArray::try_new(
            values2.into_array(),
            offsets2.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let result = list1
            .clone()
            .into_array()
            .binary(list2.clone().into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([true, true, false]);
        assert_arrays_eq!(result, expected);

        let result = list1
            .clone()
            .into_array()
            .binary(list2.clone().into_array(), Operator::NotEq)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected);

        let result = list1
            .into_array()
            .binary(list2.into_array(), Operator::Lt)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected);
    }

    // List array compared against a constant list [3, 4]; only the middle row matches
    // (currently ignored; see the `ignore` reason).
    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_constant_comparison() {
        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
        let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list = ListArray::try_new(
            values.into_array(),
            offsets.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![3i32.into(), 4i32.into()],
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(list_scalar, 3);

        let result = list
            .into_array()
            .binary(constant.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([false, true, false]);
        assert_arrays_eq!(result, expected);
    }

    // Struct comparison goes through the nested (Arrow comparator) path.
    // In row 2 the first field already differs (true vs false), and the expected `Gt`
    // result for that row is `true`.
    #[test]
    fn test_struct_array_comparison() {
        let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
        let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]);

        let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]);
        let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]);

        let struct1 = StructArray::from_fields(&[
            ("bool_col", bool_field1.into_array()),
            ("int_col", int_field1.into_array()),
        ])
        .unwrap();

        let struct2 = StructArray::from_fields(&[
            ("bool_col", bool_field2.into_array()),
            ("int_col", int_field2.into_array()),
        ])
        .unwrap();

        let result = struct1
            .clone()
            .into_array()
            .binary(struct2.clone().into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([true, true, false]);
        assert_arrays_eq!(result, expected);

        let result = struct1
            .into_array()
            .binary(struct2.into_array(), Operator::Gt)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected);
    }

    // Structs with zero fields have no data to differ on, so every row compares equal.
    #[test]
    fn test_empty_struct_compare() {
        let empty1 = StructArray::try_new(
            FieldNames::from(Vec::<FieldName>::new()),
            Vec::new(),
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let empty2 = StructArray::try_new(
            FieldNames::from(Vec::<FieldName>::new()),
            Vec::new(),
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let result = empty1
            .into_array()
            .binary(empty2.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([true, true, true, true, true]);
        assert_arrays_eq!(result, expected);
    }

    // Three valid but empty lists (offsets and sizes all zero over an empty child).
    // Comparing them with themselves must produce valid (non-null) results.
    #[test]
    fn test_empty_list() {
        let list = ListViewArray::new(
            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            Validity::AllValid,
        );

        let result = list
            .clone()
            .into_array()
            .binary(list.into_array(), Operator::Eq)
            .unwrap();
        assert!(result.scalar_at(0).unwrap().is_valid());
        assert!(result.scalar_at(1).unwrap().is_valid());
        assert!(result.scalar_at(2).unwrap().is_valid());
    }
}