1use std::cmp::Ordering;
5
6use arrow_array::BooleanArray;
7use arrow_buffer::NullBuffer;
8use arrow_ord::cmp;
9use arrow_ord::ord::make_comparator;
10use arrow_schema::Field;
11use arrow_schema::SortOptions;
12use vortex_error::VortexResult;
13use vortex_error::vortex_err;
14
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::array::ArrayView;
20use crate::array::VTable;
21use crate::arrays::Constant;
22use crate::arrays::ConstantArray;
23use crate::arrays::ScalarFn;
24use crate::arrays::scalar_fn::ExactScalarFn;
25use crate::arrays::scalar_fn::ScalarFnArrayExt;
26use crate::arrays::scalar_fn::ScalarFnArrayView;
27use crate::arrow::ArrowSessionExt;
28use crate::arrow::Datum;
29use crate::arrow::from_arrow_array_with_len;
30use crate::dtype::DType;
31use crate::dtype::Nullability;
32use crate::kernel::ExecuteParentKernel;
33use crate::scalar::Scalar;
34use crate::scalar_fn::fns::binary::Binary;
35use crate::scalar_fn::fns::operators::CompareOperator;
36
37pub trait CompareKernel: VTable {
43 fn compare(
44 lhs: ArrayView<'_, Self>,
45 rhs: &ArrayRef,
46 operator: CompareOperator,
47 ctx: &mut ExecutionCtx,
48 ) -> VortexResult<Option<ArrayRef>>;
49}
50
/// Newtype wrapper around an encoding `V`.
///
/// It exists so that a `CompareKernel` implementation can be plugged in as an
/// `ExecuteParentKernel` for binary scalar-function parents.
#[derive(Debug, Default)]
pub struct CompareExecuteAdaptor<V>(pub V);
58
59impl<V> ExecuteParentKernel<V> for CompareExecuteAdaptor<V>
60where
61 V: CompareKernel,
62{
63 type Parent = ExactScalarFn<Binary>;
64
65 fn execute_parent(
66 &self,
67 array: ArrayView<'_, V>,
68 parent: ScalarFnArrayView<'_, Binary>,
69 child_idx: usize,
70 ctx: &mut ExecutionCtx,
71 ) -> VortexResult<Option<ArrayRef>> {
72 let Ok(cmp_op) = CompareOperator::try_from(*parent.options) else {
74 return Ok(None);
75 };
76
77 let Some(scalar_fn_array) = parent.as_opt::<ScalarFn>() else {
79 return Ok(None);
80 };
81 let (cmp_op, other) = match child_idx {
84 0 => (cmp_op, scalar_fn_array.get_child(1)),
85 1 => (cmp_op.swap(), scalar_fn_array.get_child(0)),
86 _ => return Ok(None),
87 };
88
89 let len = array.len();
90 let nullable = array.dtype().is_nullable() || other.dtype().is_nullable();
91
92 if len == 0 {
94 return Ok(Some(
95 Canonical::empty(&DType::Bool(nullable.into())).into_array(),
96 ));
97 }
98
99 if other.as_constant().is_some_and(|s| s.is_null()) {
101 return Ok(Some(
102 ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), len).into_array(),
103 ));
104 }
105
106 V::compare(array, other, cmp_op, ctx)
107 }
108}
109
110pub(crate) fn execute_compare(
115 lhs: &ArrayRef,
116 rhs: &ArrayRef,
117 op: CompareOperator,
118 ctx: &mut ExecutionCtx,
119) -> VortexResult<ArrayRef> {
120 let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
121
122 if lhs.is_empty() {
123 return Ok(Canonical::empty(&DType::Bool(nullable.into())).into_array());
124 }
125
126 let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
127 let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
128 if left_constant_null || right_constant_null {
129 return Ok(
130 ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), lhs.len()).into_array(),
131 );
132 }
133
134 if let (Some(lhs_const), Some(rhs_const)) = (lhs.as_opt::<Constant>(), rhs.as_opt::<Constant>())
136 {
137 let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op)?;
138 return Ok(ConstantArray::new(result, lhs.len()).into_array());
139 }
140
141 arrow_compare_arrays(lhs, rhs, op, ctx)
142}
143
144fn arrow_compare_arrays(
146 left: &ArrayRef,
147 right: &ArrayRef,
148 operator: CompareOperator,
149 ctx: &mut ExecutionCtx,
150) -> VortexResult<ArrayRef> {
151 assert_eq!(left.len(), right.len());
152
153 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
154
155 let arrow_array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
158 let session = ctx.session().clone();
159 let lhs = session.arrow().execute_arrow(left.clone(), None, ctx)?;
160 let target_field = Field::new("", lhs.data_type().clone(), right.dtype().is_nullable());
161 let rhs = session
162 .arrow()
163 .execute_arrow(right.clone(), Some(&target_field), ctx)?;
164
165 compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)?
166 } else {
167 let lhs = Datum::try_new(left, ctx)?;
169 let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type(), ctx)?;
170
171 match operator {
172 CompareOperator::Eq => cmp::eq(&lhs, &rhs)?,
173 CompareOperator::NotEq => cmp::neq(&lhs, &rhs)?,
174 CompareOperator::Gt => cmp::gt(&lhs, &rhs)?,
175 CompareOperator::Gte => cmp::gt_eq(&lhs, &rhs)?,
176 CompareOperator::Lt => cmp::lt(&lhs, &rhs)?,
177 CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?,
178 }
179 };
180
181 from_arrow_array_with_len(&arrow_array, left.len(), nullable)
182}
183
184pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> VortexResult<Scalar> {
185 if lhs.is_null() | rhs.is_null() {
186 return Ok(Scalar::null(DType::Bool(Nullability::Nullable)));
187 }
188
189 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
190
191 let ordering = lhs.partial_cmp(rhs).ok_or_else(|| {
193 vortex_err!(
194 "Cannot compare scalars with incompatible types: {} and {}",
195 lhs.dtype(),
196 rhs.dtype()
197 )
198 })?;
199
200 let b = match operator {
201 CompareOperator::Eq => ordering.is_eq(),
202 CompareOperator::NotEq => ordering.is_ne(),
203 CompareOperator::Gt => ordering.is_gt(),
204 CompareOperator::Gte => ordering.is_ge(),
205 CompareOperator::Lt => ordering.is_lt(),
206 CompareOperator::Lte => ordering.is_le(),
207 };
208
209 Ok(Scalar::bool(b, nullability))
210}
211
212pub fn compare_nested_arrow_arrays(
220 lhs: &dyn arrow_array::Array,
221 rhs: &dyn arrow_array::Array,
222 operator: CompareOperator,
223) -> VortexResult<BooleanArray> {
224 let compare_arrays_at = make_comparator(lhs, rhs, SortOptions::default())?;
225
226 let cmp_fn = match operator {
227 CompareOperator::Eq => Ordering::is_eq,
228 CompareOperator::NotEq => Ordering::is_ne,
229 CompareOperator::Gt => Ordering::is_gt,
230 CompareOperator::Gte => Ordering::is_ge,
231 CompareOperator::Lt => Ordering::is_lt,
232 CompareOperator::Lte => Ordering::is_le,
233 };
234
235 let values = (0..lhs.len())
236 .map(|i| cmp_fn(compare_arrays_at(i, i)))
237 .collect();
238 let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
239
240 Ok(BooleanArray::new(values, nulls))
241}
242
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::ListViewArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::extension::datetime::TimeUnit;
    use crate::extension::datetime::Timestamp;
    use crate::extension::datetime::TimestampOptions;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::binary::compare::ConstantArray;
    use crate::scalar_fn::fns::binary::scalar_cmp;
    use crate::scalar_fn::fns::operators::CompareOperator;
    use crate::scalar_fn::fns::operators::Operator;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Exercises every comparison operator on nullable bool arrays.
    #[test]
    fn test_bool_basic_comparisons() {
        use vortex_buffer::BitBuffer;

        // Runs `lhs <op> rhs` and returns whatever `to_int_indices` reports for
        // the resulting bool array.
        #[expect(deprecated)]
        let matching_indices = |lhs: ArrayRef, rhs: ArrayRef, op: Operator| {
            to_int_indices(lhs.binary(rhs, op).unwrap().to_bool()).unwrap()
        };

        let arr = BoolArray::new(
            BitBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        )
        .into_array();
        let other = BoolArray::new(
            BitBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        )
        .into_array();

        // An array compared with itself: index 0 is null in the input.
        assert_eq!(
            matching_indices(arr.clone(), arr.clone(), Operator::Eq),
            [1u64, 2, 3, 4]
        );
        let no_matches: [u64; 0] = [];
        assert_eq!(
            matching_indices(arr.clone(), arr.clone(), Operator::NotEq),
            no_matches
        );

        // Ordering operators against a second array.
        assert_eq!(
            matching_indices(arr.clone(), other.clone(), Operator::Lte),
            [2u64, 3, 4]
        );
        assert_eq!(
            matching_indices(arr.clone(), other.clone(), Operator::Lt),
            [4u64]
        );
        assert_eq!(
            matching_indices(other.clone(), arr.clone(), Operator::Gte),
            [2u64, 3, 4]
        );
        assert_eq!(matching_indices(other, arr, Operator::Gt), [4u64]);
    }

    /// Two constant arrays compare to a result of the same length.
    #[test]
    fn constant_compare() {
        let twos = ConstantArray::new(Scalar::from(2u32), 10).into_array();
        let tens = ConstantArray::new(Scalar::from(10u32), 10).into_array();

        let result = twos.binary(tens, Operator::Gt).unwrap();
        assert_eq!(result.len(), 10);

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let first = result.execute_scalar(0, &mut ctx).unwrap();
        assert_eq!(first.as_bool().value(), Some(false));
    }

    /// Equality works across VarBin / VarBinView encodings in either order.
    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let actual = left.binary(right, Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([true, true]);
        assert_arrays_eq!(actual, expected);
    }

    /// List arrays compare element-wise per row.
    #[test]
    fn test_list_array_comparison() {
        // Three two-element lists over the given six values.
        let build_list = |elements: [i32; 6]| {
            ListArray::try_new(
                PrimitiveArray::from_iter(elements).into_array(),
                PrimitiveArray::from_iter([0i32, 2, 4, 6]).into_array(),
                Validity::NonNullable,
            )
            .unwrap()
            .into_array()
        };

        let list1 = build_list([1, 2, 3, 4, 5, 6]);
        let list2 = build_list([1, 2, 3, 4, 7, 8]);

        for (op, outcome) in [
            (Operator::Eq, [true, true, false]),
            (Operator::NotEq, [false, false, true]),
            (Operator::Lt, [false, false, true]),
        ] {
            let result = list1.clone().binary(list2.clone(), op).unwrap();
            let expected = BoolArray::from_iter(outcome);
            assert_arrays_eq!(result, expected);
        }
    }

    /// A list array can be compared against a constant list scalar.
    #[test]
    fn test_list_array_constant_comparison() {
        let lists = ListArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]).into_array(),
            PrimitiveArray::from_iter([0i32, 2, 4, 6]).into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let needle = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![3i32.into(), 4i32.into()],
            Nullability::NonNullable,
        );
        let needles = ConstantArray::new(needle, 3).into_array();

        let result = lists.binary(needles, Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([false, true, false]);
        assert_arrays_eq!(result, expected);
    }

    /// Struct arrays compare field by field.
    #[test]
    fn test_struct_array_comparison() {
        let build_struct = |flags: [Option<bool>; 3], ints: [i32; 3]| {
            StructArray::from_fields(&[
                ("bool_col", BoolArray::from_iter(flags).into_array()),
                ("int_col", PrimitiveArray::from_iter(ints).into_array()),
            ])
            .unwrap()
            .into_array()
        };

        let struct1 = build_struct([Some(true), Some(false), Some(true)], [1, 2, 3]);
        let struct2 = build_struct([Some(true), Some(false), Some(false)], [1, 2, 4]);

        for (op, outcome) in [
            (Operator::Eq, [true, true, false]),
            (Operator::Gt, [false, false, true]),
        ] {
            let result = struct1.clone().binary(struct2.clone(), op).unwrap();
            let expected = BoolArray::from_iter(outcome);
            assert_arrays_eq!(result, expected);
        }
    }

    /// Field-less structs of equal length compare equal in every row.
    #[test]
    fn test_empty_struct_compare() {
        let fieldless_struct = || {
            StructArray::try_new(
                FieldNames::from(Vec::<FieldName>::new()),
                Vec::new(),
                5,
                Validity::NonNullable,
            )
            .unwrap()
            .into_array()
        };

        let result = fieldless_struct()
            .binary(fieldless_struct(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([true, true, true, true, true]);
        assert_arrays_eq!(result, expected);
    }

    /// Struct fields that use different binary encodings still compare.
    #[test]
    fn struct_compare_mixed_binary_encodings() {
        // Left struct: field stored as VarBin.
        let varbin_field = VarBinArray::from(vec![
            "apple".as_bytes(),
            "banana".as_bytes(),
            "cherry".as_bytes(),
        ]);
        let left = StructArray::from_fields(&[("data", varbin_field.into_array())])
            .unwrap()
            .into_array();

        // Right struct: field stored as VarBinView, last value differs.
        let varbinview_field = VarBinViewArray::from_iter_bin([
            "apple".as_bytes(),
            "banana".as_bytes(),
            "durian".as_bytes(),
        ]);
        let right = StructArray::from_fields(&[("data", varbinview_field.into_array())])
            .unwrap()
            .into_array();

        let result = left.binary(right, Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([true, true, false]);
        assert_arrays_eq!(result, expected);
    }

    /// Timestamps with different units are rejected by `scalar_cmp` for every
    /// operator.
    #[test]
    fn scalar_cmp_incompatible_extension_types_errors() {
        let timestamp = |unit: TimeUnit, value: i64| {
            Scalar::extension::<Timestamp>(TimestampOptions { unit, tz: None }, Scalar::from(value))
        };

        let ms_scalar = timestamp(TimeUnit::Milliseconds, 1704067200000i64);
        let s_scalar = timestamp(TimeUnit::Seconds, 1704067200i64);

        for op in [
            CompareOperator::Gt,
            CompareOperator::Lt,
            CompareOperator::Gte,
            CompareOperator::Lte,
            CompareOperator::Eq,
            CompareOperator::NotEq,
        ] {
            assert!(scalar_cmp(&ms_scalar, &s_scalar, op).is_err());
        }
    }

    /// Comparing empty lists yields valid (non-null) results for every row.
    #[test]
    fn test_empty_list() {
        let list = ListViewArray::new(
            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            Validity::AllValid,
        )
        .into_array();

        let result = list.clone().binary(list, Operator::Eq).unwrap();

        for row in 0..3 {
            let scalar = result
                .execute_scalar(row, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap();
            assert!(scalar.is_valid());
        }
    }
}