1use std::cmp::Ordering;
5
6use arrow_array::BooleanArray;
7use arrow_buffer::NullBuffer;
8use arrow_ord::cmp;
9use arrow_ord::ord::make_comparator;
10use arrow_schema::SortOptions;
11use vortex_error::VortexResult;
12use vortex_error::vortex_err;
13
14use crate::ArrayRef;
15use crate::Canonical;
16use crate::ExecutionCtx;
17use crate::IntoArray;
18use crate::arrays::Constant;
19use crate::arrays::ConstantArray;
20use crate::arrays::ScalarFnVTable;
21use crate::arrays::scalar_fn::ExactScalarFn;
22use crate::arrays::scalar_fn::ScalarFnArrayView;
23use crate::arrow::Datum;
24use crate::arrow::IntoArrowArray;
25use crate::arrow::from_arrow_array_with_len;
26use crate::dtype::DType;
27use crate::dtype::Nullability;
28use crate::kernel::ExecuteParentKernel;
29use crate::scalar::Scalar;
30use crate::scalar_fn::fns::binary::Binary;
31use crate::scalar_fn::fns::operators::CompareOperator;
32use crate::vtable::VTable;
33
34pub trait CompareKernel: VTable {
40 fn compare(
41 lhs: &Self::Array,
42 rhs: &ArrayRef,
43 operator: CompareOperator,
44 ctx: &mut ExecutionCtx,
45 ) -> VortexResult<Option<ArrayRef>>;
46}
47
/// Adapts a [`CompareKernel`] so it can be registered as an [`ExecuteParentKernel`] that
/// handles binary comparison parents of arrays with `V`'s encoding.
///
/// The wrapped value is usually a zero-sized vtable marker; the adaptor only needs it
/// for type dispatch.
#[derive(Debug, Default)]
pub struct CompareExecuteAdaptor<V>(pub V);
55
56impl<V> ExecuteParentKernel<V> for CompareExecuteAdaptor<V>
57where
58 V: CompareKernel,
59{
60 type Parent = ExactScalarFn<Binary>;
61
62 fn execute_parent(
63 &self,
64 array: &V::Array,
65 parent: ScalarFnArrayView<'_, Binary>,
66 child_idx: usize,
67 ctx: &mut ExecutionCtx,
68 ) -> VortexResult<Option<ArrayRef>> {
69 let Ok(cmp_op) = CompareOperator::try_from(*parent.options) else {
71 return Ok(None);
72 };
73
74 let Some(scalar_fn_array) = parent.as_opt::<ScalarFnVTable>() else {
76 return Ok(None);
77 };
78 let children = scalar_fn_array.children();
79
80 let (cmp_op, other) = match child_idx {
83 0 => (cmp_op, &children[1]),
84 1 => (cmp_op.swap(), &children[0]),
85 _ => return Ok(None),
86 };
87
88 let len = array.len();
89 let nullable = array.dtype().is_nullable() || other.dtype().is_nullable();
90
91 if len == 0 {
93 return Ok(Some(
94 Canonical::empty(&DType::Bool(nullable.into())).into_array(),
95 ));
96 }
97
98 if other.as_constant().is_some_and(|s| s.is_null()) {
100 return Ok(Some(
101 ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), len).into_array(),
102 ));
103 }
104
105 V::compare(array, other, cmp_op, ctx)
106 }
107}
108
109pub(crate) fn execute_compare(
114 lhs: &ArrayRef,
115 rhs: &ArrayRef,
116 op: CompareOperator,
117) -> VortexResult<ArrayRef> {
118 let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
119
120 if lhs.is_empty() {
121 return Ok(Canonical::empty(&DType::Bool(nullable.into())).into_array());
122 }
123
124 let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
125 let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
126 if left_constant_null || right_constant_null {
127 return Ok(
128 ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), lhs.len()).into_array(),
129 );
130 }
131
132 if let (Some(lhs_const), Some(rhs_const)) = (lhs.as_opt::<Constant>(), rhs.as_opt::<Constant>())
134 {
135 let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op)?;
136 return Ok(ConstantArray::new(result, lhs.len()).into_array());
137 }
138
139 arrow_compare_arrays(lhs, rhs, op)
140}
141
142fn arrow_compare_arrays(
144 left: &ArrayRef,
145 right: &ArrayRef,
146 operator: CompareOperator,
147) -> VortexResult<ArrayRef> {
148 assert_eq!(left.len(), right.len());
149
150 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
151
152 let arrow_array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
155 let rhs = right.to_array().into_arrow_preferred()?;
156 let lhs = left.to_array().into_arrow(rhs.data_type())?;
157
158 assert!(
159 lhs.data_type().equals_datatype(rhs.data_type()),
160 "lhs data_type: {}, rhs data_type: {}",
161 lhs.data_type(),
162 rhs.data_type()
163 );
164
165 compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)?
166 } else {
167 let lhs = Datum::try_new(left)?;
169 let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type())?;
170
171 match operator {
172 CompareOperator::Eq => cmp::eq(&lhs, &rhs)?,
173 CompareOperator::NotEq => cmp::neq(&lhs, &rhs)?,
174 CompareOperator::Gt => cmp::gt(&lhs, &rhs)?,
175 CompareOperator::Gte => cmp::gt_eq(&lhs, &rhs)?,
176 CompareOperator::Lt => cmp::lt(&lhs, &rhs)?,
177 CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?,
178 }
179 };
180
181 from_arrow_array_with_len(&arrow_array, left.len(), nullable)
182}
183
184pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> VortexResult<Scalar> {
185 if lhs.is_null() | rhs.is_null() {
186 return Ok(Scalar::null(DType::Bool(Nullability::Nullable)));
187 }
188
189 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
190
191 let ordering = lhs.partial_cmp(rhs).ok_or_else(|| {
193 vortex_err!(
194 "Cannot compare scalars with incompatible types: {} and {}",
195 lhs.dtype(),
196 rhs.dtype()
197 )
198 })?;
199
200 let b = match operator {
201 CompareOperator::Eq => ordering.is_eq(),
202 CompareOperator::NotEq => ordering.is_ne(),
203 CompareOperator::Gt => ordering.is_gt(),
204 CompareOperator::Gte => ordering.is_ge(),
205 CompareOperator::Lt => ordering.is_lt(),
206 CompareOperator::Lte => ordering.is_le(),
207 };
208
209 Ok(Scalar::bool(b, nullability))
210}
211
212pub fn compare_nested_arrow_arrays(
220 lhs: &dyn arrow_array::Array,
221 rhs: &dyn arrow_array::Array,
222 operator: CompareOperator,
223) -> VortexResult<BooleanArray> {
224 let compare_arrays_at = make_comparator(lhs, rhs, SortOptions::default())?;
225
226 let cmp_fn = match operator {
227 CompareOperator::Eq => Ordering::is_eq,
228 CompareOperator::NotEq => Ordering::is_ne,
229 CompareOperator::Gt => Ordering::is_gt,
230 CompareOperator::Gte => Ordering::is_ge,
231 CompareOperator::Lt => Ordering::is_lt,
232 CompareOperator::Lte => Ordering::is_le,
233 };
234
235 let values = (0..lhs.len())
236 .map(|i| cmp_fn(compare_arrays_at(i, i)))
237 .collect();
238 let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
239
240 Ok(BooleanArray::new(values, nulls))
241}
242
#[cfg(test)]
mod tests {
    // Tests for the comparison entry points, exercised through `ArrayBuiltins::binary`
    // and, for scalars, `scalar_cmp` directly.
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::ListViewArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::extension::datetime::TimeUnit;
    use crate::extension::datetime::Timestamp;
    use crate::extension::datetime::TimestampOptions;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::binary::compare::ConstantArray;
    use crate::scalar_fn::fns::binary::scalar_cmp;
    use crate::scalar_fn::fns::operators::CompareOperator;
    use crate::scalar_fn::fns::operators::Operator;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Exercises all six comparison operators on nullable boolean arrays.
    ///
    /// Index 0 is null in both inputs. None of the expected index lists include 0, so a
    /// null row never counts as a match.
    #[test]
    fn test_bool_basic_comparisons() {
        use vortex_buffer::BitBuffer;

        let arr = BoolArray::new(
            BitBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        // arr == arr: every valid row matches.
        let matches = arr
            .clone()
            .into_array()
            .binary(arr.clone().into_array(), Operator::Eq)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]);

        // arr != arr: no row matches.
        let matches = arr
            .clone()
            .into_array()
            .binary(arr.clone().into_array(), Operator::NotEq)
            .unwrap()
            .to_bool();
        let empty: [u64; 0] = [];
        assert_eq!(to_int_indices(matches).unwrap(), empty);

        let other = BoolArray::new(
            BitBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        // Ordering on booleans treats false < true: row 1 is true vs false.
        let matches = arr
            .clone()
            .into_array()
            .binary(other.clone().into_array(), Operator::Lte)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = arr
            .clone()
            .into_array()
            .binary(other.clone().into_array(), Operator::Lt)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);

        // The mirrored operators with swapped operands must give the same rows.
        let matches = other
            .clone()
            .into_array()
            .binary(arr.clone().into_array(), Operator::Gte)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = other
            .into_array()
            .binary(arr.into_array(), Operator::Gt)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
    }

    /// Two constant arrays are compared and the result keeps the input length.
    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        // 2 > 10 is false.
        let result = left
            .into_array()
            .binary(right.into_array(), Operator::Gt)
            .unwrap();
        assert_eq!(result.len(), 10);
        let scalar = result.scalar_at(0).unwrap();
        assert_eq!(scalar.as_bool().value(), Some(false));
    }

    /// Equal string/binary values compare equal regardless of which variable-width
    /// encoding (VarBin vs VarBinView) each side uses, in either operand order.
    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let res = left.binary(right, Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([true, true]);
        assert_arrays_eq!(res, expected);
    }

    /// list1 = [[1, 2], [3, 4], [5, 6]] vs list2 = [[1, 2], [3, 4], [7, 8]].
    /// Only the last row differs, and [5, 6] < [7, 8].
    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_comparison() {
        let values1 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
        let offsets1 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list1 = ListArray::try_new(
            values1.into_array(),
            offsets1.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let values2 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 7, 8]);
        let offsets2 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list2 = ListArray::try_new(
            values2.into_array(),
            offsets2.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let result = list1
            .clone()
            .into_array()
            .binary(list2.clone().into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([true, true, false]);
        assert_arrays_eq!(result, expected);

        let result = list1
            .clone()
            .into_array()
            .binary(list2.clone().into_array(), Operator::NotEq)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected);

        let result = list1
            .into_array()
            .binary(list2.into_array(), Operator::Lt)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected);
    }

    /// Lists [[1, 2], [3, 4], [5, 6]] compared for equality against the constant list
    /// [3, 4]: only the middle row matches.
    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_constant_comparison() {
        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
        let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list = ListArray::try_new(
            values.into_array(),
            offsets.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![3i32.into(), 4i32.into()],
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(list_scalar, 3);

        let result = list
            .into_array()
            .binary(constant.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([false, true, false]);
        assert_arrays_eq!(result, expected);
    }

    /// Struct rows are (bool_col, int_col):
    /// struct1 = (t, 1), (f, 2), (t, 3); struct2 = (t, 1), (f, 2), (f, 4).
    #[test]
    fn test_struct_array_comparison() {
        let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
        let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]);

        let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]);
        let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]);

        let struct1 = StructArray::from_fields(&[
            ("bool_col", bool_field1.into_array()),
            ("int_col", int_field1.into_array()),
        ])
        .unwrap();

        let struct2 = StructArray::from_fields(&[
            ("bool_col", bool_field2.into_array()),
            ("int_col", int_field2.into_array()),
        ])
        .unwrap();

        let result = struct1
            .clone()
            .into_array()
            .binary(struct2.clone().into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([true, true, false]);
        assert_arrays_eq!(result, expected);

        // Row 2 is greater because bool_col (true vs false) is decided first,
        // even though int_col is 3 < 4.
        let result = struct1
            .into_array()
            .binary(struct2.into_array(), Operator::Gt)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected);
    }

    /// Structs with zero fields (length 5) compare equal on every row.
    #[test]
    fn test_empty_struct_compare() {
        let empty1 = StructArray::try_new(
            FieldNames::from(Vec::<FieldName>::new()),
            Vec::new(),
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let empty2 = StructArray::try_new(
            FieldNames::from(Vec::<FieldName>::new()),
            Vec::new(),
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let result = empty1
            .into_array()
            .binary(empty2.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([true, true, true, true, true]);
        assert_arrays_eq!(result, expected);
    }

    /// Both scalars encode 2024-01-01T00:00:00, one in milliseconds and one in seconds.
    /// Comparing their raw storage values would be meaningless, so `scalar_cmp` must
    /// return an error for every operator.
    #[test]
    fn scalar_cmp_incompatible_extension_types_errors() {
        let ms_scalar = Scalar::extension::<Timestamp>(
            TimestampOptions {
                unit: TimeUnit::Milliseconds,
                tz: None,
            },
            Scalar::from(1704067200000i64),
        );
        let s_scalar = Scalar::extension::<Timestamp>(
            TimestampOptions {
                unit: TimeUnit::Seconds,
                tz: None,
            },
            Scalar::from(1704067200i64),
        );

        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gt).is_err());
        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lt).is_err());
        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gte).is_err());
        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lte).is_err());
        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Eq).is_err());
        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::NotEq).is_err());
    }

    /// A ListView of three lists over an empty element array. Both index buffers are all
    /// zeros, so every list is empty. Comparing it with itself must yield valid
    /// (non-null) results.
    #[test]
    fn test_empty_list() {
        let list = ListViewArray::new(
            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            Validity::AllValid,
        );

        let result = list
            .clone()
            .into_array()
            .binary(list.into_array(), Operator::Eq)
            .unwrap();
        assert!(result.scalar_at(0).unwrap().is_valid());
        assert!(result.scalar_at(1).unwrap().is_valid());
        assert!(result.scalar_at(2).unwrap().is_valid());
    }
}