1use std::cmp::Ordering;
5
6use arrow_array::BooleanArray;
7use arrow_buffer::NullBuffer;
8use arrow_ord::cmp;
9use arrow_ord::ord::make_comparator;
10use arrow_schema::SortOptions;
11use vortex_error::VortexResult;
12use vortex_error::vortex_err;
13
14use crate::ArrayRef;
15use crate::Canonical;
16use crate::ExecutionCtx;
17use crate::IntoArray;
18use crate::array::ArrayView;
19use crate::array::VTable;
20use crate::arrays::Constant;
21use crate::arrays::ConstantArray;
22use crate::arrays::ScalarFn;
23use crate::arrays::scalar_fn::ExactScalarFn;
24use crate::arrays::scalar_fn::ScalarFnArrayExt;
25use crate::arrays::scalar_fn::ScalarFnArrayView;
26use crate::arrow::ArrowArrayExecutor;
27use crate::arrow::Datum;
28use crate::arrow::from_arrow_array_with_len;
29use crate::dtype::DType;
30use crate::dtype::Nullability;
31use crate::kernel::ExecuteParentKernel;
32use crate::scalar::Scalar;
33use crate::scalar_fn::fns::binary::Binary;
34use crate::scalar_fn::fns::operators::CompareOperator;
35
36pub trait CompareKernel: VTable {
42 fn compare(
43 lhs: ArrayView<'_, Self>,
44 rhs: &ArrayRef,
45 operator: CompareOperator,
46 ctx: &mut ExecutionCtx,
47 ) -> VortexResult<Option<ArrayRef>>;
48}
49
/// Adaptor that exposes a [`CompareKernel`] implementation as an [`ExecuteParentKernel`].
///
/// The wrapped value `V` is the encoding whose `compare` kernel is invoked when an array of that
/// encoding appears as a child of a binary scalar function carrying a comparison operator.
#[derive(Debug, Default)]
pub struct CompareExecuteAdaptor<V>(pub V);
57
58impl<V> ExecuteParentKernel<V> for CompareExecuteAdaptor<V>
59where
60 V: CompareKernel,
61{
62 type Parent = ExactScalarFn<Binary>;
63
64 fn execute_parent(
65 &self,
66 array: ArrayView<'_, V>,
67 parent: ScalarFnArrayView<'_, Binary>,
68 child_idx: usize,
69 ctx: &mut ExecutionCtx,
70 ) -> VortexResult<Option<ArrayRef>> {
71 let Ok(cmp_op) = CompareOperator::try_from(*parent.options) else {
73 return Ok(None);
74 };
75
76 let Some(scalar_fn_array) = parent.as_opt::<ScalarFn>() else {
78 return Ok(None);
79 };
80 let (cmp_op, other) = match child_idx {
83 0 => (cmp_op, scalar_fn_array.get_child(1)),
84 1 => (cmp_op.swap(), scalar_fn_array.get_child(0)),
85 _ => return Ok(None),
86 };
87
88 let len = array.len();
89 let nullable = array.dtype().is_nullable() || other.dtype().is_nullable();
90
91 if len == 0 {
93 return Ok(Some(
94 Canonical::empty(&DType::Bool(nullable.into())).into_array(),
95 ));
96 }
97
98 if other.as_constant().is_some_and(|s| s.is_null()) {
100 return Ok(Some(
101 ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), len).into_array(),
102 ));
103 }
104
105 V::compare(array, other, cmp_op, ctx)
106 }
107}
108
109pub(crate) fn execute_compare(
114 lhs: &ArrayRef,
115 rhs: &ArrayRef,
116 op: CompareOperator,
117 ctx: &mut ExecutionCtx,
118) -> VortexResult<ArrayRef> {
119 let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
120
121 if lhs.is_empty() {
122 return Ok(Canonical::empty(&DType::Bool(nullable.into())).into_array());
123 }
124
125 let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
126 let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
127 if left_constant_null || right_constant_null {
128 return Ok(
129 ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), lhs.len()).into_array(),
130 );
131 }
132
133 if let (Some(lhs_const), Some(rhs_const)) = (lhs.as_opt::<Constant>(), rhs.as_opt::<Constant>())
135 {
136 let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op)?;
137 return Ok(ConstantArray::new(result, lhs.len()).into_array());
138 }
139
140 arrow_compare_arrays(lhs, rhs, op, ctx)
141}
142
143fn arrow_compare_arrays(
145 left: &ArrayRef,
146 right: &ArrayRef,
147 operator: CompareOperator,
148 ctx: &mut ExecutionCtx,
149) -> VortexResult<ArrayRef> {
150 assert_eq!(left.len(), right.len());
151
152 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
153
154 let arrow_array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
157 let rhs = right.clone().execute_arrow(None, ctx)?;
158 let lhs = left.clone().execute_arrow(Some(rhs.data_type()), ctx)?;
159
160 assert!(
161 lhs.data_type().equals_datatype(rhs.data_type()),
162 "lhs data_type: {}, rhs data_type: {}",
163 lhs.data_type(),
164 rhs.data_type()
165 );
166
167 compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)?
168 } else {
169 let lhs = Datum::try_new(left, ctx)?;
171 let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type(), ctx)?;
172
173 match operator {
174 CompareOperator::Eq => cmp::eq(&lhs, &rhs)?,
175 CompareOperator::NotEq => cmp::neq(&lhs, &rhs)?,
176 CompareOperator::Gt => cmp::gt(&lhs, &rhs)?,
177 CompareOperator::Gte => cmp::gt_eq(&lhs, &rhs)?,
178 CompareOperator::Lt => cmp::lt(&lhs, &rhs)?,
179 CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?,
180 }
181 };
182
183 from_arrow_array_with_len(&arrow_array, left.len(), nullable)
184}
185
186pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> VortexResult<Scalar> {
187 if lhs.is_null() | rhs.is_null() {
188 return Ok(Scalar::null(DType::Bool(Nullability::Nullable)));
189 }
190
191 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
192
193 let ordering = lhs.partial_cmp(rhs).ok_or_else(|| {
195 vortex_err!(
196 "Cannot compare scalars with incompatible types: {} and {}",
197 lhs.dtype(),
198 rhs.dtype()
199 )
200 })?;
201
202 let b = match operator {
203 CompareOperator::Eq => ordering.is_eq(),
204 CompareOperator::NotEq => ordering.is_ne(),
205 CompareOperator::Gt => ordering.is_gt(),
206 CompareOperator::Gte => ordering.is_ge(),
207 CompareOperator::Lt => ordering.is_lt(),
208 CompareOperator::Lte => ordering.is_le(),
209 };
210
211 Ok(Scalar::bool(b, nullability))
212}
213
214pub fn compare_nested_arrow_arrays(
222 lhs: &dyn arrow_array::Array,
223 rhs: &dyn arrow_array::Array,
224 operator: CompareOperator,
225) -> VortexResult<BooleanArray> {
226 let compare_arrays_at = make_comparator(lhs, rhs, SortOptions::default())?;
227
228 let cmp_fn = match operator {
229 CompareOperator::Eq => Ordering::is_eq,
230 CompareOperator::NotEq => Ordering::is_ne,
231 CompareOperator::Gt => Ordering::is_gt,
232 CompareOperator::Gte => Ordering::is_ge,
233 CompareOperator::Lt => Ordering::is_lt,
234 CompareOperator::Lte => Ordering::is_le,
235 };
236
237 let values = (0..lhs.len())
238 .map(|i| cmp_fn(compare_arrays_at(i, i)))
239 .collect();
240 let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
241
242 Ok(BooleanArray::new(values, nulls))
243}
244
245#[cfg(test)]
246mod tests {
247 use std::sync::Arc;
248
249 use rstest::rstest;
250 use vortex_buffer::buffer;
251
252 use crate::ArrayRef;
253 use crate::IntoArray;
254 use crate::LEGACY_SESSION;
255 #[expect(deprecated)]
256 use crate::ToCanonical as _;
257 use crate::VortexSessionExecute;
258 use crate::arrays::BoolArray;
259 use crate::arrays::ListArray;
260 use crate::arrays::ListViewArray;
261 use crate::arrays::PrimitiveArray;
262 use crate::arrays::StructArray;
263 use crate::arrays::VarBinArray;
264 use crate::arrays::VarBinViewArray;
265 use crate::assert_arrays_eq;
266 use crate::builtins::ArrayBuiltins;
267 use crate::dtype::DType;
268 use crate::dtype::FieldName;
269 use crate::dtype::FieldNames;
270 use crate::dtype::Nullability;
271 use crate::dtype::PType;
272 use crate::extension::datetime::TimeUnit;
273 use crate::extension::datetime::Timestamp;
274 use crate::extension::datetime::TimestampOptions;
275 use crate::scalar::Scalar;
276 use crate::scalar_fn::fns::binary::compare::ConstantArray;
277 use crate::scalar_fn::fns::binary::scalar_cmp;
278 use crate::scalar_fn::fns::operators::CompareOperator;
279 use crate::scalar_fn::fns::operators::Operator;
280 use crate::test_harness::to_int_indices;
281 use crate::validity::Validity;
282
283 #[test]
284 fn test_bool_basic_comparisons() {
285 use vortex_buffer::BitBuffer;
286
287 let arr = BoolArray::new(
288 BitBuffer::from_iter([true, true, false, true, false]),
289 Validity::from_iter([false, true, true, true, true]),
290 );
291
292 #[expect(deprecated)]
293 let matches = arr
294 .clone()
295 .into_array()
296 .binary(arr.clone().into_array(), Operator::Eq)
297 .unwrap()
298 .to_bool();
299 assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]);
300
301 #[expect(deprecated)]
302 let matches = arr
303 .clone()
304 .into_array()
305 .binary(arr.clone().into_array(), Operator::NotEq)
306 .unwrap()
307 .to_bool();
308 let empty: [u64; 0] = [];
309 assert_eq!(to_int_indices(matches).unwrap(), empty);
310
311 let other = BoolArray::new(
312 BitBuffer::from_iter([false, false, false, true, true]),
313 Validity::from_iter([false, true, true, true, true]),
314 );
315
316 #[expect(deprecated)]
317 let matches = arr
318 .clone()
319 .into_array()
320 .binary(other.clone().into_array(), Operator::Lte)
321 .unwrap()
322 .to_bool();
323 assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);
324
325 #[expect(deprecated)]
326 let matches = arr
327 .clone()
328 .into_array()
329 .binary(other.clone().into_array(), Operator::Lt)
330 .unwrap()
331 .to_bool();
332 assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
333
334 #[expect(deprecated)]
335 let matches = other
336 .clone()
337 .into_array()
338 .binary(arr.clone().into_array(), Operator::Gte)
339 .unwrap()
340 .to_bool();
341 assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);
342
343 #[expect(deprecated)]
344 let matches = other
345 .into_array()
346 .binary(arr.into_array(), Operator::Gt)
347 .unwrap()
348 .to_bool();
349 assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
350 }
351
352 #[test]
353 fn constant_compare() {
354 let left = ConstantArray::new(Scalar::from(2u32), 10);
355 let right = ConstantArray::new(Scalar::from(10u32), 10);
356
357 let result = left
358 .into_array()
359 .binary(right.into_array(), Operator::Gt)
360 .unwrap();
361 assert_eq!(result.len(), 10);
362 let scalar = result
363 .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
364 .unwrap();
365 assert_eq!(scalar.as_bool().value(), Some(false));
366 }
367
368 #[rstest]
369 #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
370 #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
371 #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
372 #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
373 fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
374 let res = left.binary(right, Operator::Eq).unwrap();
375 let expected = BoolArray::from_iter([true, true]);
376 assert_arrays_eq!(res, expected);
377 }
378
379 #[ignore = "Arrow's ListView cannot be compared"]
380 #[test]
381 fn test_list_array_comparison() {
382 let values1 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
383 let offsets1 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
384 let list1 = ListArray::try_new(
385 values1.into_array(),
386 offsets1.into_array(),
387 Validity::NonNullable,
388 )
389 .unwrap();
390
391 let values2 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 7, 8]);
392 let offsets2 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
393 let list2 = ListArray::try_new(
394 values2.into_array(),
395 offsets2.into_array(),
396 Validity::NonNullable,
397 )
398 .unwrap();
399
400 let result = list1
401 .clone()
402 .into_array()
403 .binary(list2.clone().into_array(), Operator::Eq)
404 .unwrap();
405 let expected = BoolArray::from_iter([true, true, false]);
406 assert_arrays_eq!(result, expected);
407
408 let result = list1
409 .clone()
410 .into_array()
411 .binary(list2.clone().into_array(), Operator::NotEq)
412 .unwrap();
413 let expected = BoolArray::from_iter([false, false, true]);
414 assert_arrays_eq!(result, expected);
415
416 let result = list1
417 .into_array()
418 .binary(list2.into_array(), Operator::Lt)
419 .unwrap();
420 let expected = BoolArray::from_iter([false, false, true]);
421 assert_arrays_eq!(result, expected);
422 }
423
424 #[ignore = "Arrow's ListView cannot be compared"]
425 #[test]
426 fn test_list_array_constant_comparison() {
427 let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
428 let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
429 let list = ListArray::try_new(
430 values.into_array(),
431 offsets.into_array(),
432 Validity::NonNullable,
433 )
434 .unwrap();
435
436 let list_scalar = Scalar::list(
437 Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
438 vec![3i32.into(), 4i32.into()],
439 Nullability::NonNullable,
440 );
441 let constant = ConstantArray::new(list_scalar, 3);
442
443 let result = list
444 .into_array()
445 .binary(constant.into_array(), Operator::Eq)
446 .unwrap();
447 let expected = BoolArray::from_iter([false, true, false]);
448 assert_arrays_eq!(result, expected);
449 }
450
451 #[test]
452 fn test_struct_array_comparison() {
453 let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
454 let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]);
455
456 let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]);
457 let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]);
458
459 let struct1 = StructArray::from_fields(&[
460 ("bool_col", bool_field1.into_array()),
461 ("int_col", int_field1.into_array()),
462 ])
463 .unwrap();
464
465 let struct2 = StructArray::from_fields(&[
466 ("bool_col", bool_field2.into_array()),
467 ("int_col", int_field2.into_array()),
468 ])
469 .unwrap();
470
471 let result = struct1
472 .clone()
473 .into_array()
474 .binary(struct2.clone().into_array(), Operator::Eq)
475 .unwrap();
476 let expected = BoolArray::from_iter([true, true, false]);
477 assert_arrays_eq!(result, expected);
478
479 let result = struct1
480 .into_array()
481 .binary(struct2.into_array(), Operator::Gt)
482 .unwrap();
483 let expected = BoolArray::from_iter([false, false, true]);
484 assert_arrays_eq!(result, expected);
485 }
486
487 #[test]
488 fn test_empty_struct_compare() {
489 let empty1 = StructArray::try_new(
490 FieldNames::from(Vec::<FieldName>::new()),
491 Vec::new(),
492 5,
493 Validity::NonNullable,
494 )
495 .unwrap();
496
497 let empty2 = StructArray::try_new(
498 FieldNames::from(Vec::<FieldName>::new()),
499 Vec::new(),
500 5,
501 Validity::NonNullable,
502 )
503 .unwrap();
504
505 let result = empty1
506 .into_array()
507 .binary(empty2.into_array(), Operator::Eq)
508 .unwrap();
509 let expected = BoolArray::from_iter([true, true, true, true, true]);
510 assert_arrays_eq!(result, expected);
511 }
512
513 #[test]
517 fn scalar_cmp_incompatible_extension_types_errors() {
518 let ms_scalar = Scalar::extension::<Timestamp>(
519 TimestampOptions {
520 unit: TimeUnit::Milliseconds,
521 tz: None,
522 },
523 Scalar::from(1704067200000i64),
524 );
525 let s_scalar = Scalar::extension::<Timestamp>(
526 TimestampOptions {
527 unit: TimeUnit::Seconds,
528 tz: None,
529 },
530 Scalar::from(1704067200i64),
531 );
532
533 assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gt).is_err());
535 assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lt).is_err());
536 assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gte).is_err());
537 assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lte).is_err());
538 assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Eq).is_err());
539 assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::NotEq).is_err());
540 }
541
542 #[test]
543 fn test_empty_list() {
544 let list = ListViewArray::new(
545 BoolArray::from_iter(Vec::<bool>::new()).into_array(),
546 buffer![0i32, 0i32, 0i32].into_array(),
547 buffer![0i32, 0i32, 0i32].into_array(),
548 Validity::AllValid,
549 );
550
551 let result = list
552 .clone()
553 .into_array()
554 .binary(list.into_array(), Operator::Eq)
555 .unwrap();
556 assert!(
557 result
558 .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
559 .unwrap()
560 .is_valid()
561 );
562 assert!(
563 result
564 .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())
565 .unwrap()
566 .is_valid()
567 );
568 assert!(
569 result
570 .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
571 .unwrap()
572 .is_valid()
573 );
574 }
575}