1use std::cmp::Ordering;
5
6use arrow_array::BooleanArray;
7use arrow_buffer::NullBuffer;
8use arrow_ord::cmp;
9use arrow_ord::ord::make_comparator;
10use arrow_schema::SortOptions;
11use vortex_error::VortexResult;
12
13use crate::ArrayRef;
14use crate::Canonical;
15use crate::ExecutionCtx;
16use crate::IntoArray;
17use crate::arrays::Constant;
18use crate::arrays::ConstantArray;
19use crate::arrays::ScalarFnVTable;
20use crate::arrays::scalar_fn::ExactScalarFn;
21use crate::arrays::scalar_fn::ScalarFnArrayView;
22use crate::arrow::Datum;
23use crate::arrow::IntoArrowArray;
24use crate::arrow::from_arrow_array_with_len;
25use crate::dtype::DType;
26use crate::dtype::Nullability;
27use crate::kernel::ExecuteParentKernel;
28use crate::scalar::Scalar;
29use crate::scalar_fn::fns::binary::Binary;
30use crate::scalar_fn::fns::operators::CompareOperator;
31use crate::vtable::VTable;
32
33pub trait CompareKernel: VTable {
39 fn compare(
40 lhs: &Self::Array,
41 rhs: &ArrayRef,
42 operator: CompareOperator,
43 ctx: &mut ExecutionCtx,
44 ) -> VortexResult<Option<ArrayRef>>;
45}
46
/// Newtype that lets a [`CompareKernel`] implementor be used as an
/// [`ExecuteParentKernel`] for binary scalar-function parents.
#[derive(Debug, Default)]
pub struct CompareExecuteAdaptor<V>(pub V);
54
55impl<V> ExecuteParentKernel<V> for CompareExecuteAdaptor<V>
56where
57 V: CompareKernel,
58{
59 type Parent = ExactScalarFn<Binary>;
60
61 fn execute_parent(
62 &self,
63 array: &V::Array,
64 parent: ScalarFnArrayView<'_, Binary>,
65 child_idx: usize,
66 ctx: &mut ExecutionCtx,
67 ) -> VortexResult<Option<ArrayRef>> {
68 let Ok(cmp_op) = CompareOperator::try_from(*parent.options) else {
70 return Ok(None);
71 };
72
73 let Some(scalar_fn_array) = parent.as_opt::<ScalarFnVTable>() else {
75 return Ok(None);
76 };
77 let children = scalar_fn_array.children();
78
79 let (cmp_op, other) = match child_idx {
82 0 => (cmp_op, &children[1]),
83 1 => (cmp_op.swap(), &children[0]),
84 _ => return Ok(None),
85 };
86
87 let len = array.len();
88 let nullable = array.dtype().is_nullable() || other.dtype().is_nullable();
89
90 if len == 0 {
92 return Ok(Some(
93 Canonical::empty(&DType::Bool(nullable.into())).into_array(),
94 ));
95 }
96
97 if other.as_constant().is_some_and(|s| s.is_null()) {
99 return Ok(Some(
100 ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), len).into_array(),
101 ));
102 }
103
104 V::compare(array, other, cmp_op, ctx)
105 }
106}
107
108pub(crate) fn execute_compare(
113 lhs: &ArrayRef,
114 rhs: &ArrayRef,
115 op: CompareOperator,
116) -> VortexResult<ArrayRef> {
117 let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
118
119 if lhs.is_empty() {
120 return Ok(Canonical::empty(&DType::Bool(nullable.into())).into_array());
121 }
122
123 let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
124 let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
125 if left_constant_null || right_constant_null {
126 return Ok(
127 ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), lhs.len()).into_array(),
128 );
129 }
130
131 if let (Some(lhs_const), Some(rhs_const)) = (lhs.as_opt::<Constant>(), rhs.as_opt::<Constant>())
133 {
134 let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op);
135 return Ok(ConstantArray::new(result, lhs.len()).into_array());
136 }
137
138 arrow_compare_arrays(lhs, rhs, op)
139}
140
141fn arrow_compare_arrays(
143 left: &ArrayRef,
144 right: &ArrayRef,
145 operator: CompareOperator,
146) -> VortexResult<ArrayRef> {
147 assert_eq!(left.len(), right.len());
148
149 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
150
151 let array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
154 let rhs = right.to_array().into_arrow_preferred()?;
155 let lhs = left.to_array().into_arrow(rhs.data_type())?;
156
157 assert!(
158 lhs.data_type().equals_datatype(rhs.data_type()),
159 "lhs data_type: {}, rhs data_type: {}",
160 lhs.data_type(),
161 rhs.data_type()
162 );
163
164 compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)?
165 } else {
166 let lhs = Datum::try_new(left)?;
168 let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type())?;
169
170 match operator {
171 CompareOperator::Eq => cmp::eq(&lhs, &rhs)?,
172 CompareOperator::NotEq => cmp::neq(&lhs, &rhs)?,
173 CompareOperator::Gt => cmp::gt(&lhs, &rhs)?,
174 CompareOperator::Gte => cmp::gt_eq(&lhs, &rhs)?,
175 CompareOperator::Lt => cmp::lt(&lhs, &rhs)?,
176 CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?,
177 }
178 };
179 from_arrow_array_with_len(&array, left.len(), nullable)
180}
181
182pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> Scalar {
183 if lhs.is_null() | rhs.is_null() {
184 Scalar::null(DType::Bool(Nullability::Nullable))
185 } else {
186 let b = match operator {
187 CompareOperator::Eq => lhs == rhs,
188 CompareOperator::NotEq => lhs != rhs,
189 CompareOperator::Gt => lhs > rhs,
190 CompareOperator::Gte => lhs >= rhs,
191 CompareOperator::Lt => lhs < rhs,
192 CompareOperator::Lte => lhs <= rhs,
193 };
194
195 Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
196 }
197}
198
199pub fn compare_nested_arrow_arrays(
207 lhs: &dyn arrow_array::Array,
208 rhs: &dyn arrow_array::Array,
209 operator: CompareOperator,
210) -> VortexResult<BooleanArray> {
211 let compare_arrays_at = make_comparator(lhs, rhs, SortOptions::default())?;
212
213 let cmp_fn = match operator {
214 CompareOperator::Eq => Ordering::is_eq,
215 CompareOperator::NotEq => Ordering::is_ne,
216 CompareOperator::Gt => Ordering::is_gt,
217 CompareOperator::Gte => Ordering::is_ge,
218 CompareOperator::Lt => Ordering::is_lt,
219 CompareOperator::Lte => Ordering::is_le,
220 };
221
222 let values = (0..lhs.len())
223 .map(|i| cmp_fn(compare_arrays_at(i, i)))
224 .collect();
225 let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
226
227 Ok(BooleanArray::new(values, nulls))
228}
229
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::ListViewArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::binary::compare::ConstantArray;
    use crate::scalar_fn::fns::operators::Operator;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    // Every comparison operator on nullable bool arrays; row 0 is null on both sides
    // and must therefore never show up among the matching indices.
    #[test]
    fn test_bool_basic_comparisons() {
        use vortex_buffer::BitBuffer;

        let nullable_bools = |bits: [bool; 5]| {
            BoolArray::new(
                BitBuffer::from_iter(bits),
                Validity::from_iter([false, true, true, true, true]),
            )
        };
        // Indices of the rows for which `lhs <op> rhs` holds.
        let matching_rows = |lhs: &BoolArray, rhs: &BoolArray, op: Operator| {
            let mask = lhs
                .clone()
                .into_array()
                .binary(rhs.clone().into_array(), op)
                .unwrap()
                .to_bool();
            to_int_indices(mask).unwrap()
        };

        let arr = nullable_bools([true, true, false, true, false]);
        let other = nullable_bools([false, false, false, true, true]);

        assert_eq!(matching_rows(&arr, &arr, Operator::Eq), [1u64, 2, 3, 4]);

        let empty: [u64; 0] = [];
        assert_eq!(matching_rows(&arr, &arr, Operator::NotEq), empty);

        assert_eq!(matching_rows(&arr, &other, Operator::Lte), [2u64, 3, 4]);
        assert_eq!(matching_rows(&arr, &other, Operator::Lt), [4u64]);
        assert_eq!(matching_rows(&other, &arr, Operator::Gte), [2u64, 3, 4]);
        assert_eq!(matching_rows(&other, &arr, Operator::Gt), [4u64]);
    }

    // Two constant arrays compare to a result of the same length.
    #[test]
    fn constant_compare() {
        let twos = ConstantArray::new(Scalar::from(2u32), 10).into_array();
        let tens = ConstantArray::new(Scalar::from(10u32), 10).into_array();

        let result = twos.binary(tens, Operator::Gt).unwrap();

        assert_eq!(result.len(), 10);
        let first = result.scalar_at(0).unwrap();
        assert_eq!(first.as_bool().value(), Some(false));
    }

    // Equality must work across VarBin / VarBinView encodings, for both strings and bytes.
    #[rstest]
    #[case(
        VarBinArray::from(vec!["a", "b"]).into_array(),
        VarBinViewArray::from_iter_str(["a", "b"]).into_array()
    )]
    #[case(
        VarBinViewArray::from_iter_str(["a", "b"]).into_array(),
        VarBinArray::from(vec!["a", "b"]).into_array()
    )]
    #[case(
        VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(),
        VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array()
    )]
    #[case(
        VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(),
        VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array()
    )]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let actual = left.binary(right, Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([true, true]);
        assert_arrays_eq!(actual, expected);
    }

    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_comparison() {
        // Three two-element lists sliced out of six values.
        let make_list = |values: [i32; 6]| {
            ListArray::try_new(
                PrimitiveArray::from_iter(values).into_array(),
                PrimitiveArray::from_iter([0i32, 2, 4, 6]).into_array(),
                Validity::NonNullable,
            )
            .unwrap()
            .into_array()
        };

        let lhs = make_list([1i32, 2, 3, 4, 5, 6]);
        let rhs = make_list([1i32, 2, 3, 4, 7, 8]);

        for (op, outcome) in [
            (Operator::Eq, [true, true, false]),
            (Operator::NotEq, [false, false, true]),
            (Operator::Lt, [false, false, true]),
        ] {
            let result = lhs.clone().binary(rhs.clone(), op).unwrap();
            let expected = BoolArray::from_iter(outcome);
            assert_arrays_eq!(result, expected);
        }
    }

    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_constant_comparison() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]).into_array(),
            PrimitiveArray::from_iter([0i32, 2, 4, 6]).into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let needle = Scalar::list(
            element_dtype,
            vec![3i32.into(), 4i32.into()],
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(needle, 3).into_array();

        let result = list.into_array().binary(constant, Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([false, true, false]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_struct_array_comparison() {
        let make_struct = |flags: [Option<bool>; 3], ints: [i32; 3]| {
            StructArray::from_fields(&[
                ("bool_col", BoolArray::from_iter(flags).into_array()),
                ("int_col", PrimitiveArray::from_iter(ints).into_array()),
            ])
            .unwrap()
            .into_array()
        };

        let lhs = make_struct([Some(true), Some(false), Some(true)], [1i32, 2, 3]);
        let rhs = make_struct([Some(true), Some(false), Some(false)], [1i32, 2, 4]);

        let result = lhs.clone().binary(rhs.clone(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([true, true, false]);
        assert_arrays_eq!(result, expected);

        let result = lhs.binary(rhs, Operator::Gt).unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected);
    }

    // Structs without any fields are all equal to each other.
    #[test]
    fn test_empty_struct_compare() {
        let fieldless_struct = || {
            StructArray::try_new(
                FieldNames::from(Vec::<FieldName>::new()),
                Vec::new(),
                5,
                Validity::NonNullable,
            )
            .unwrap()
            .into_array()
        };

        let result = fieldless_struct()
            .binary(fieldless_struct(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([true, true, true, true, true]);
        assert_arrays_eq!(result, expected);
    }

    // Comparing empty (but valid) lists must yield valid results, not nulls.
    #[test]
    fn test_empty_list() {
        let zeros = || buffer![0i32, 0i32, 0i32].into_array();
        let list = ListViewArray::new(
            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
            zeros(),
            zeros(),
            Validity::AllValid,
        )
        .into_array();

        let result = list.clone().binary(list, Operator::Eq).unwrap();

        for row in 0..3 {
            assert!(result.scalar_at(row).unwrap().is_valid());
        }
    }
}