1use std::cmp::Ordering;
5
6use arrow_array::BooleanArray;
7use arrow_buffer::NullBuffer;
8use arrow_ord::cmp;
9use arrow_ord::ord::make_comparator;
10use arrow_schema::SortOptions;
11use vortex_error::VortexResult;
12use vortex_error::vortex_err;
13
14use crate::ArrayRef;
15use crate::Canonical;
16use crate::ExecutionCtx;
17use crate::IntoArray;
18use crate::array::ArrayView;
19use crate::array::VTable;
20use crate::arrays::Constant;
21use crate::arrays::ConstantArray;
22use crate::arrays::ScalarFnVTable;
23use crate::arrays::scalar_fn::ExactScalarFn;
24use crate::arrays::scalar_fn::ScalarFnArrayExt;
25use crate::arrays::scalar_fn::ScalarFnArrayView;
26use crate::arrow::Datum;
27use crate::arrow::IntoArrowArray;
28use crate::arrow::from_arrow_array_with_len;
29use crate::dtype::DType;
30use crate::dtype::Nullability;
31use crate::kernel::ExecuteParentKernel;
32use crate::scalar::Scalar;
33use crate::scalar_fn::fns::binary::Binary;
34use crate::scalar_fn::fns::operators::CompareOperator;
35
/// Encoding-specific kernel for element-wise comparison of an array of encoding `Self`
/// against another array.
///
/// When invoked through [`CompareExecuteAdaptor`], the `Self` array is always passed as `lhs`.
/// If it was originally the right-hand child, `operator` has already been swapped. Empty inputs
/// and comparisons against a constant-null `rhs` are answered by the adaptor before this is called.
pub trait CompareKernel: VTable {
    /// Compares `lhs` to `rhs` element-wise using `operator`.
    ///
    /// Returns `Ok(Some(result))` with a boolean array when this kernel handles the case.
    /// The adaptor returns an `Ok(None)` result unchanged. By analogy with the adaptor's own early
    /// returns it presumably means "not handled here, fall back to another path" — TODO confirm.
    fn compare(
        lhs: ArrayView<'_, Self>,
        rhs: &ArrayRef,
        operator: CompareOperator,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>>;
}
49
/// Adapts a [`CompareKernel`] implementation into an [`ExecuteParentKernel`]. The resulting
/// kernel applies when an array of encoding `V` is a child of a `Binary` scalar-function array
/// whose operator is a comparison.
///
/// The wrapped value is not read by `execute_parent`; dispatch goes through the associated
/// function `V::compare`.
#[derive(Default, Debug)]
pub struct CompareExecuteAdaptor<V>(pub V);
57
58impl<V> ExecuteParentKernel<V> for CompareExecuteAdaptor<V>
59where
60 V: CompareKernel,
61{
62 type Parent = ExactScalarFn<Binary>;
63
64 fn execute_parent(
65 &self,
66 array: ArrayView<'_, V>,
67 parent: ScalarFnArrayView<'_, Binary>,
68 child_idx: usize,
69 ctx: &mut ExecutionCtx,
70 ) -> VortexResult<Option<ArrayRef>> {
71 let Ok(cmp_op) = CompareOperator::try_from(*parent.options) else {
73 return Ok(None);
74 };
75
76 let Some(scalar_fn_array) = parent.as_opt::<ScalarFnVTable>() else {
78 return Ok(None);
79 };
80 let (cmp_op, other) = match child_idx {
83 0 => (cmp_op, scalar_fn_array.get_child(1)),
84 1 => (cmp_op.swap(), scalar_fn_array.get_child(0)),
85 _ => return Ok(None),
86 };
87
88 let len = array.len();
89 let nullable = array.dtype().is_nullable() || other.dtype().is_nullable();
90
91 if len == 0 {
93 return Ok(Some(
94 Canonical::empty(&DType::Bool(nullable.into())).into_array(),
95 ));
96 }
97
98 if other.as_constant().is_some_and(|s| s.is_null()) {
100 return Ok(Some(
101 ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), len).into_array(),
102 ));
103 }
104
105 V::compare(array, other, cmp_op, ctx)
106 }
107}
108
109pub(crate) fn execute_compare(
114 lhs: &ArrayRef,
115 rhs: &ArrayRef,
116 op: CompareOperator,
117) -> VortexResult<ArrayRef> {
118 let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
119
120 if lhs.is_empty() {
121 return Ok(Canonical::empty(&DType::Bool(nullable.into())).into_array());
122 }
123
124 let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
125 let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
126 if left_constant_null || right_constant_null {
127 return Ok(
128 ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), lhs.len()).into_array(),
129 );
130 }
131
132 if let (Some(lhs_const), Some(rhs_const)) = (lhs.as_opt::<Constant>(), rhs.as_opt::<Constant>())
134 {
135 let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op)?;
136 return Ok(ConstantArray::new(result, lhs.len()).into_array());
137 }
138
139 arrow_compare_arrays(lhs, rhs, op)
140}
141
142fn arrow_compare_arrays(
144 left: &ArrayRef,
145 right: &ArrayRef,
146 operator: CompareOperator,
147) -> VortexResult<ArrayRef> {
148 assert_eq!(left.len(), right.len());
149
150 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
151
152 let arrow_array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
155 let rhs = right.clone().into_arrow_preferred()?;
156 let lhs = left.clone().into_arrow(rhs.data_type())?;
157
158 assert!(
159 lhs.data_type().equals_datatype(rhs.data_type()),
160 "lhs data_type: {}, rhs data_type: {}",
161 lhs.data_type(),
162 rhs.data_type()
163 );
164
165 compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)?
166 } else {
167 let lhs = Datum::try_new(left)?;
169 let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type())?;
170
171 match operator {
172 CompareOperator::Eq => cmp::eq(&lhs, &rhs)?,
173 CompareOperator::NotEq => cmp::neq(&lhs, &rhs)?,
174 CompareOperator::Gt => cmp::gt(&lhs, &rhs)?,
175 CompareOperator::Gte => cmp::gt_eq(&lhs, &rhs)?,
176 CompareOperator::Lt => cmp::lt(&lhs, &rhs)?,
177 CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?,
178 }
179 };
180
181 from_arrow_array_with_len(&arrow_array, left.len(), nullable)
182}
183
184pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> VortexResult<Scalar> {
185 if lhs.is_null() | rhs.is_null() {
186 return Ok(Scalar::null(DType::Bool(Nullability::Nullable)));
187 }
188
189 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
190
191 let ordering = lhs.partial_cmp(rhs).ok_or_else(|| {
193 vortex_err!(
194 "Cannot compare scalars with incompatible types: {} and {}",
195 lhs.dtype(),
196 rhs.dtype()
197 )
198 })?;
199
200 let b = match operator {
201 CompareOperator::Eq => ordering.is_eq(),
202 CompareOperator::NotEq => ordering.is_ne(),
203 CompareOperator::Gt => ordering.is_gt(),
204 CompareOperator::Gte => ordering.is_ge(),
205 CompareOperator::Lt => ordering.is_lt(),
206 CompareOperator::Lte => ordering.is_le(),
207 };
208
209 Ok(Scalar::bool(b, nullability))
210}
211
212pub fn compare_nested_arrow_arrays(
220 lhs: &dyn arrow_array::Array,
221 rhs: &dyn arrow_array::Array,
222 operator: CompareOperator,
223) -> VortexResult<BooleanArray> {
224 let compare_arrays_at = make_comparator(lhs, rhs, SortOptions::default())?;
225
226 let cmp_fn = match operator {
227 CompareOperator::Eq => Ordering::is_eq,
228 CompareOperator::NotEq => Ordering::is_ne,
229 CompareOperator::Gt => Ordering::is_gt,
230 CompareOperator::Gte => Ordering::is_ge,
231 CompareOperator::Lt => Ordering::is_lt,
232 CompareOperator::Lte => Ordering::is_le,
233 };
234
235 let values = (0..lhs.len())
236 .map(|i| cmp_fn(compare_arrays_at(i, i)))
237 .collect();
238 let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
239
240 Ok(BooleanArray::new(values, nulls))
241}
242
// Unit tests for the comparison paths: Arrow-backed comparisons of flat, nested and
// empty-struct arrays, the constant fast path, and `scalar_cmp` error handling.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::ListViewArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::extension::datetime::TimeUnit;
    use crate::extension::datetime::Timestamp;
    use crate::extension::datetime::TimestampOptions;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::binary::compare::ConstantArray;
    use crate::scalar_fn::fns::binary::scalar_cmp;
    use crate::scalar_fn::fns::operators::CompareOperator;
    use crate::scalar_fn::fns::operators::Operator;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    // Every comparison operator on boolean arrays. Index 0 is null in both inputs, so it never
    // appears among the matching indices.
    #[test]
    fn test_bool_basic_comparisons() {
        use vortex_buffer::BitBuffer;

        let arr = BoolArray::new(
            BitBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = arr
            .clone()
            .into_array()
            .binary(arr.clone().into_array(), Operator::Eq)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]);

        // An array is never "not equal" to itself at its valid positions.
        let matches = arr
            .clone()
            .into_array()
            .binary(arr.clone().into_array(), Operator::NotEq)
            .unwrap()
            .to_bool();
        let empty: [u64; 0] = [];
        assert_eq!(to_int_indices(matches).unwrap(), empty);

        let other = BoolArray::new(
            BitBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = arr
            .clone()
            .into_array()
            .binary(other.clone().into_array(), Operator::Lte)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = arr
            .clone()
            .into_array()
            .binary(other.clone().into_array(), Operator::Lt)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);

        // With operands swapped, `Gte` / `Gt` mirror the `Lte` / `Lt` results above.
        let matches = other
            .clone()
            .into_array()
            .binary(arr.clone().into_array(), Operator::Gte)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = other
            .into_array()
            .binary(arr.into_array(), Operator::Gt)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
    }

    // Two constant arrays: 2 > 10 is false, and the result keeps the input length.
    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        let result = left
            .into_array()
            .binary(right.into_array(), Operator::Gt)
            .unwrap();
        assert_eq!(result.len(), 10);
        let scalar = result
            .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
            .unwrap();
        assert_eq!(scalar.as_bool().value(), Some(false));
    }

    // The same string/binary contents compare equal across VarBin and VarBinView encodings,
    // in either operand order.
    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let res = left.binary(right, Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([true, true]);
        assert_arrays_eq!(res, expected);
    }

    // Element-wise list comparison: only the last row differs ([5, 6] vs [7, 8]).
    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_comparison() {
        let values1 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
        let offsets1 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list1 = ListArray::try_new(
            values1.into_array(),
            offsets1.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let values2 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 7, 8]);
        let offsets2 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list2 = ListArray::try_new(
            values2.into_array(),
            offsets2.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let result = list1
            .clone()
            .into_array()
            .binary(list2.clone().into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([true, true, false]);
        assert_arrays_eq!(result, expected);

        let result = list1
            .clone()
            .into_array()
            .binary(list2.clone().into_array(), Operator::NotEq)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected);

        let result = list1
            .into_array()
            .binary(list2.into_array(), Operator::Lt)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected);
    }

    // A list array compared against a constant list scalar [3, 4]: only row 1 matches.
    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_constant_comparison() {
        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
        let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list = ListArray::try_new(
            values.into_array(),
            offsets.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![3i32.into(), 4i32.into()],
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(list_scalar, 3);

        let result = list
            .into_array()
            .binary(constant.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([false, true, false]);
        assert_arrays_eq!(result, expected);
    }

    // Struct comparison. For `Gt`, row 2 is greater because `bool_col` true > false decides it,
    // even though `int_col` 3 < 4, so earlier fields take precedence.
    #[test]
    fn test_struct_array_comparison() {
        let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
        let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]);

        let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]);
        let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]);

        let struct1 = StructArray::from_fields(&[
            ("bool_col", bool_field1.into_array()),
            ("int_col", int_field1.into_array()),
        ])
        .unwrap();

        let struct2 = StructArray::from_fields(&[
            ("bool_col", bool_field2.into_array()),
            ("int_col", int_field2.into_array()),
        ])
        .unwrap();

        let result = struct1
            .clone()
            .into_array()
            .binary(struct2.clone().into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([true, true, false]);
        assert_arrays_eq!(result, expected);

        let result = struct1
            .into_array()
            .binary(struct2.into_array(), Operator::Gt)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected);
    }

    // Zero-field structs of length 5 compare equal at every position.
    #[test]
    fn test_empty_struct_compare() {
        let empty1 = StructArray::try_new(
            FieldNames::from(Vec::<FieldName>::new()),
            Vec::new(),
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let empty2 = StructArray::try_new(
            FieldNames::from(Vec::<FieldName>::new()),
            Vec::new(),
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let result = empty1
            .into_array()
            .binary(empty2.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([true, true, true, true, true]);
        assert_arrays_eq!(result, expected);
    }

    // Timestamps with different units are incompatible for `scalar_cmp`. Every operator errors,
    // even though the two values denote the same instant (1704067200000 ms == 1704067200 s).
    #[test]
    fn scalar_cmp_incompatible_extension_types_errors() {
        let ms_scalar = Scalar::extension::<Timestamp>(
            TimestampOptions {
                unit: TimeUnit::Milliseconds,
                tz: None,
            },
            Scalar::from(1704067200000i64),
        );
        let s_scalar = Scalar::extension::<Timestamp>(
            TimestampOptions {
                unit: TimeUnit::Seconds,
                tz: None,
            },
            Scalar::from(1704067200i64),
        );

        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gt).is_err());
        // The remaining operators must fail the same way.
        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lt).is_err());
        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gte).is_err());
        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lte).is_err());
        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Eq).is_err());
        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::NotEq).is_err());
    }

    // A list-view of three zero-length lists, compared with itself. The offset and size buffers
    // are all zeros and the validity is AllValid, so every result position must be valid.
    #[test]
    fn test_empty_list() {
        let list = ListViewArray::new(
            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            Validity::AllValid,
        );

        let result = list
            .clone()
            .into_array()
            .binary(list.into_array(), Operator::Eq)
            .unwrap();
        assert!(
            result
                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
                .is_valid()
        );
        assert!(
            result
                .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
                .is_valid()
        );
        assert!(
            result
                .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
                .is_valid()
        );
    }
}