1use std::cmp::Ordering;
5
6use arrow_array::BooleanArray;
7use arrow_buffer::NullBuffer;
8use arrow_ord::cmp;
9use arrow_ord::ord::make_comparator;
10use arrow_schema::Field;
11use arrow_schema::SortOptions;
12use vortex_error::VortexResult;
13use vortex_error::vortex_err;
14
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::array::ArrayView;
20use crate::array::VTable;
21use crate::arrays::Constant;
22use crate::arrays::ConstantArray;
23use crate::arrays::ScalarFn;
24use crate::arrays::scalar_fn::ExactScalarFn;
25use crate::arrays::scalar_fn::ScalarFnArrayExt;
26use crate::arrays::scalar_fn::ScalarFnArrayView;
27use crate::arrow::ArrowSessionExt;
28use crate::arrow::Datum;
29use crate::arrow::from_arrow_columnar;
30use crate::dtype::DType;
31use crate::dtype::Nullability;
32use crate::kernel::ExecuteParentKernel;
33use crate::scalar::Scalar;
34use crate::scalar_fn::fns::binary::Binary;
35use crate::scalar_fn::fns::operators::CompareOperator;
36
/// Encoding-specific kernel for comparing an array of this encoding against another array.
///
/// When invoked through [`CompareExecuteAdaptor`], the array of this encoding is always passed
/// as `lhs`. If it was originally the right-hand operand, the operator has already been
/// swapped. The adaptor also answers empty inputs and constant-null `rhs` itself, so
/// implementations are not called for those cases on that path.
pub trait CompareKernel: VTable {
    /// Compares `lhs` with `rhs` element-wise using `operator`.
    ///
    /// Returning `Ok(None)` signals that this kernel did not handle the inputs. Presumably the
    /// caller then falls back to a generic execution path (confirm against the kernel
    /// dispatcher).
    fn compare(
        lhs: ArrayView<'_, Self>,
        rhs: &ArrayRef,
        operator: CompareOperator,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>>;
}
50
/// Adapts a [`CompareKernel`] implementation into an [`ExecuteParentKernel`] for `Binary`
/// scalar-function parents whose operator is a comparison.
#[derive(Default, Debug)]
pub struct CompareExecuteAdaptor<V>(pub V);
58
59impl<V> ExecuteParentKernel<V> for CompareExecuteAdaptor<V>
60where
61 V: CompareKernel,
62{
63 type Parent = ExactScalarFn<Binary>;
64
65 fn execute_parent(
66 &self,
67 array: ArrayView<'_, V>,
68 parent: ScalarFnArrayView<'_, Binary>,
69 child_idx: usize,
70 ctx: &mut ExecutionCtx,
71 ) -> VortexResult<Option<ArrayRef>> {
72 let Ok(cmp_op) = CompareOperator::try_from(*parent.options) else {
74 return Ok(None);
75 };
76
77 let Some(scalar_fn_array) = parent.as_opt::<ScalarFn>() else {
79 return Ok(None);
80 };
81 let (cmp_op, other) = match child_idx {
84 0 => (cmp_op, scalar_fn_array.get_child(1)),
85 1 => (cmp_op.swap(), scalar_fn_array.get_child(0)),
86 _ => return Ok(None),
87 };
88
89 let len = array.len();
90 let nullable = array.dtype().is_nullable() || other.dtype().is_nullable();
91
92 if len == 0 {
94 return Ok(Some(
95 Canonical::empty(&DType::Bool(nullable.into())).into_array(),
96 ));
97 }
98
99 if other.as_constant().is_some_and(|s| s.is_null()) {
101 return Ok(Some(
102 ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), len).into_array(),
103 ));
104 }
105
106 V::compare(array, other, cmp_op, ctx)
107 }
108}
109
110pub(crate) fn execute_compare(
115 lhs: &ArrayRef,
116 rhs: &ArrayRef,
117 op: CompareOperator,
118 ctx: &mut ExecutionCtx,
119) -> VortexResult<ArrayRef> {
120 let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
121
122 if lhs.is_empty() {
123 return Ok(Canonical::empty(&DType::Bool(nullable.into())).into_array());
124 }
125
126 let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
127 let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
128 if left_constant_null || right_constant_null {
129 return Ok(
130 ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), lhs.len()).into_array(),
131 );
132 }
133
134 if let (Some(lhs_const), Some(rhs_const)) = (lhs.as_opt::<Constant>(), rhs.as_opt::<Constant>())
136 {
137 let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op)?;
138 return Ok(ConstantArray::new(result, lhs.len()).into_array());
139 }
140
141 arrow_compare_arrays(lhs, rhs, op, ctx)
142}
143
144fn arrow_compare_arrays(
146 left: &ArrayRef,
147 right: &ArrayRef,
148 operator: CompareOperator,
149 ctx: &mut ExecutionCtx,
150) -> VortexResult<ArrayRef> {
151 assert_eq!(left.len(), right.len());
152
153 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
154
155 let arrow_array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
158 let session = ctx.session().clone();
159 let lhs = session.arrow().execute_arrow(left.clone(), None, ctx)?;
160 let target_field = Field::new("", lhs.data_type().clone(), right.dtype().is_nullable());
161 let rhs = session
162 .arrow()
163 .execute_arrow(right.clone(), Some(&target_field), ctx)?;
164
165 compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)?
166 } else {
167 let lhs = Datum::try_new(left, ctx)?;
169 let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type(), ctx)?;
170
171 match operator {
172 CompareOperator::Eq => cmp::eq(&lhs, &rhs)?,
173 CompareOperator::NotEq => cmp::neq(&lhs, &rhs)?,
174 CompareOperator::Gt => cmp::gt(&lhs, &rhs)?,
175 CompareOperator::Gte => cmp::gt_eq(&lhs, &rhs)?,
176 CompareOperator::Lt => cmp::lt(&lhs, &rhs)?,
177 CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?,
178 }
179 };
180
181 from_arrow_columnar(&arrow_array, left.len(), nullable, ctx)
182}
183
184pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> VortexResult<Scalar> {
185 if lhs.is_null() | rhs.is_null() {
186 return Ok(Scalar::null(DType::Bool(Nullability::Nullable)));
187 }
188
189 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
190
191 let ordering = lhs.partial_cmp(rhs).ok_or_else(|| {
193 vortex_err!(
194 "Cannot compare scalars with incompatible types: {} and {}",
195 lhs.dtype(),
196 rhs.dtype()
197 )
198 })?;
199
200 let b = match operator {
201 CompareOperator::Eq => ordering.is_eq(),
202 CompareOperator::NotEq => ordering.is_ne(),
203 CompareOperator::Gt => ordering.is_gt(),
204 CompareOperator::Gte => ordering.is_ge(),
205 CompareOperator::Lt => ordering.is_lt(),
206 CompareOperator::Lte => ordering.is_le(),
207 };
208
209 Ok(Scalar::bool(b, nullability))
210}
211
212pub fn compare_nested_arrow_arrays(
220 lhs: &dyn arrow_array::Array,
221 rhs: &dyn arrow_array::Array,
222 operator: CompareOperator,
223) -> VortexResult<BooleanArray> {
224 let compare_arrays_at = make_comparator(lhs, rhs, SortOptions::default())?;
225
226 let cmp_fn = match operator {
227 CompareOperator::Eq => Ordering::is_eq,
228 CompareOperator::NotEq => Ordering::is_ne,
229 CompareOperator::Gt => Ordering::is_gt,
230 CompareOperator::Gte => Ordering::is_ge,
231 CompareOperator::Lt => Ordering::is_lt,
232 CompareOperator::Lte => Ordering::is_le,
233 };
234
235 let values = (0..lhs.len())
236 .map(|i| cmp_fn(compare_arrays_at(i, i)))
237 .collect();
238 let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
239
240 Ok(BooleanArray::new(values, nulls))
241}
242
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::ListViewArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::extension::datetime::TimeUnit;
    use crate::extension::datetime::Timestamp;
    use crate::extension::datetime::TimestampOptions;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::binary::compare::ConstantArray;
    use crate::scalar_fn::fns::binary::scalar_cmp;
    use crate::scalar_fn::fns::operators::CompareOperator;
    use crate::scalar_fn::fns::operators::Operator;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    // Exercises every comparison operator on two nullable bool arrays. Row 0 is null in both
    // operands and never appears in the expected indices. This assumes `to_int_indices`
    // reports only rows that are valid and true; confirm in test_harness.
    #[test]
    fn test_bool_basic_comparisons() {
        let ctx = &mut array_session().create_execution_ctx();
        let arr = BoolArray::new(
            BitBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        // An array compared with itself: Eq holds on every non-null row.
        let matches = arr
            .clone()
            .into_array()
            .binary(arr.clone().into_array(), Operator::Eq)
            .unwrap()
            .execute::<BoolArray>(ctx)
            .vortex_expect("must be a bool array");
        assert_eq!(to_int_indices(matches, ctx).unwrap(), [1u64, 2, 3, 4]);

        // ...and NotEq holds nowhere.
        let matches = arr
            .clone()
            .into_array()
            .binary(arr.clone().into_array(), Operator::NotEq)
            .unwrap()
            .execute::<BoolArray>(ctx)
            .vortex_expect("must be a bool array");
        let empty: [u64; 0] = [];
        assert_eq!(to_int_indices(matches, ctx).unwrap(), empty);

        let other = BoolArray::new(
            BitBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        // Ordering comparisons use false < true.
        let matches = arr
            .clone()
            .into_array()
            .binary(other.clone().into_array(), Operator::Lte)
            .unwrap()
            .execute::<BoolArray>(ctx)
            .vortex_expect("must be a bool array");
        assert_eq!(to_int_indices(matches, ctx).unwrap(), [2u64, 3, 4]);

        let matches = arr
            .clone()
            .into_array()
            .binary(other.clone().into_array(), Operator::Lt)
            .unwrap()
            .execute::<BoolArray>(ctx)
            .vortex_expect("must be a bool array");
        assert_eq!(to_int_indices(matches, ctx).unwrap(), [4u64]);

        // Mirror of the two cases above with the operands swapped.
        let matches = other
            .clone()
            .into_array()
            .binary(arr.clone().into_array(), Operator::Gte)
            .unwrap()
            .execute::<BoolArray>(ctx)
            .vortex_expect("must be a bool array");
        assert_eq!(to_int_indices(matches, ctx).unwrap(), [2u64, 3, 4]);

        let matches = other
            .into_array()
            .binary(arr.into_array(), Operator::Gt)
            .unwrap()
            .execute::<BoolArray>(ctx)
            .vortex_expect("must be a bool array");
        assert_eq!(to_int_indices(matches, ctx).unwrap(), [4u64]);
    }

    // Two constant operands: the result keeps the input length, and 2 > 10 is false.
    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        let result = left
            .into_array()
            .binary(right.into_array(), Operator::Gt)
            .unwrap();
        assert_eq!(result.len(), 10);
        let scalar = result
            .execute_scalar(0, &mut array_session().create_execution_ctx())
            .unwrap();
        assert_eq!(scalar.as_bool().value(), Some(false));
    }

    // Equal string/binary data stored in different encodings (VarBin vs VarBinView) must
    // compare equal in both operand orders.
    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let mut ctx = array_session().create_execution_ctx();
        let res = left.binary(right, Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([true, true]);
        assert_arrays_eq!(res, expected, &mut ctx);
    }

    // List vs list (nested path).
    // list1 = [[1, 2], [3, 4], [5, 6]]
    // list2 = [[1, 2], [3, 4], [7, 8]]
    // Only the last row differs; there [5, 6] < [7, 8].
    #[test]
    fn test_list_array_comparison() {
        let mut ctx = array_session().create_execution_ctx();
        let values1 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
        let offsets1 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list1 = ListArray::try_new(
            values1.into_array(),
            offsets1.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let values2 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 7, 8]);
        let offsets2 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list2 = ListArray::try_new(
            values2.into_array(),
            offsets2.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let result = list1
            .clone()
            .into_array()
            .binary(list2.clone().into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([true, true, false]);
        assert_arrays_eq!(result, expected, &mut ctx);

        let result = list1
            .clone()
            .into_array()
            .binary(list2.clone().into_array(), Operator::NotEq)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected, &mut ctx);

        let result = list1
            .into_array()
            .binary(list2.into_array(), Operator::Lt)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected, &mut ctx);
    }

    // List array vs a constant list scalar [3, 4]: only the middle row matches.
    #[test]
    fn test_list_array_constant_comparison() {
        let mut ctx = array_session().create_execution_ctx();
        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
        let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list = ListArray::try_new(
            values.into_array(),
            offsets.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![3i32.into(), 4i32.into()],
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(list_scalar, 3);

        let result = list
            .into_array()
            .binary(constant.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([false, true, false]);
        assert_arrays_eq!(result, expected, &mut ctx);
    }

    // Structs compare field by field in declaration order. Row 2 differs already in
    // `bool_col` (true vs false), so struct1 > struct2 there.
    #[test]
    fn test_struct_array_comparison() {
        let mut ctx = array_session().create_execution_ctx();
        let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
        let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]);

        let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]);
        let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]);

        let struct1 = StructArray::from_fields(&[
            ("bool_col", bool_field1.into_array()),
            ("int_col", int_field1.into_array()),
        ])
        .unwrap();

        let struct2 = StructArray::from_fields(&[
            ("bool_col", bool_field2.into_array()),
            ("int_col", int_field2.into_array()),
        ])
        .unwrap();

        let result = struct1
            .clone()
            .into_array()
            .binary(struct2.clone().into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([true, true, false]);
        assert_arrays_eq!(result, expected, &mut ctx);

        let result = struct1
            .into_array()
            .binary(struct2.into_array(), Operator::Gt)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected, &mut ctx);
    }

    // Zero-field structs of length 5: every row compares equal.
    #[test]
    fn test_empty_struct_compare() {
        let mut ctx = array_session().create_execution_ctx();
        let empty1 = StructArray::try_new(
            FieldNames::from(Vec::<FieldName>::new()),
            Vec::new(),
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let empty2 = StructArray::try_new(
            FieldNames::from(Vec::<FieldName>::new()),
            Vec::new(),
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let result = empty1
            .into_array()
            .binary(empty2.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([true, true, true, true, true]);
        assert_arrays_eq!(result, expected, &mut ctx);
    }

    // Struct fields backed by different binary encodings (VarBin vs VarBinView). The nested
    // comparison path must reconcile the two Arrow representations. The right side is
    // executed with the left side's Arrow type as the target.
    #[test]
    fn struct_compare_mixed_binary_encodings() {
        let mut ctx = array_session().create_execution_ctx();
        let bin_field1 = VarBinArray::from(vec![
            "apple".as_bytes(),
            "banana".as_bytes(),
            "cherry".as_bytes(),
        ]);
        let struct1 = StructArray::from_fields(&[("data", bin_field1.into_array())]).unwrap();

        let bin_field2 = VarBinViewArray::from_iter_bin([
            "apple".as_bytes(),
            "banana".as_bytes(),
            "durian".as_bytes(),
        ]);
        let struct2 = StructArray::from_fields(&[("data", bin_field2.into_array())]).unwrap();

        let result = struct1
            .into_array()
            .binary(struct2.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([true, true, false]);
        assert_arrays_eq!(result, expected, &mut ctx);
    }

    // Timestamps with different units (ms vs s) have no defined ordering. `scalar_cmp` must
    // return an error for every operator rather than a (possibly wrong) boolean.
    #[test]
    fn scalar_cmp_incompatible_extension_types_errors() {
        let ms_scalar = Scalar::extension::<Timestamp>(
            TimestampOptions {
                unit: TimeUnit::Milliseconds,
                tz: None,
            },
            Scalar::from(1704067200000i64),
        );
        let s_scalar = Scalar::extension::<Timestamp>(
            TimestampOptions {
                unit: TimeUnit::Seconds,
                tz: None,
            },
            Scalar::from(1704067200i64),
        );

        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gt).is_err());
        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lt).is_err());
        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gte).is_err());
        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lte).is_err());
        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Eq).is_err());
        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::NotEq).is_err());
    }

    // A ListView whose three rows are all empty lists (both index buffers are zero) over an
    // empty element array. Comparing it with itself must yield valid (non-null) results.
    #[test]
    fn test_empty_list() {
        let ctx = &mut array_session().create_execution_ctx();
        let list = ListViewArray::new(
            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            Validity::AllValid,
        );

        let result = list
            .clone()
            .into_array()
            .binary(list.into_array(), Operator::Eq)
            .unwrap();
        assert!(result.execute_scalar(0, ctx).unwrap().is_valid());
        assert!(result.execute_scalar(1, ctx).unwrap().is_valid());
        assert!(result.execute_scalar(2, ctx).unwrap().is_valid());
    }
}