1use core::fmt;
5use std::any::Any;
6use std::fmt::{Display, Formatter};
7use std::sync::LazyLock;
8
9use arcref::ArcRef;
10use arrow_array::BooleanArray;
11use arrow_buffer::NullBuffer;
12use arrow_ord::cmp;
13use arrow_ord::ord::make_comparator;
14use arrow_schema::SortOptions;
15use vortex_buffer::BitBuffer;
16use vortex_dtype::{DType, IntegerPType, Nullability};
17use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
18use vortex_scalar::Scalar;
19
20use crate::arrays::ConstantArray;
21use crate::arrow::{Datum, IntoArrowArray, from_arrow_array_with_len};
22use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
23use crate::vtable::VTable;
24use crate::{Array, ArrayRef, Canonical, IntoArray};
25
26static COMPARE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
27 let compute = ComputeFn::new("compare".into(), ArcRef::new_ref(&Compare));
28 for kernel in inventory::iter::<CompareKernelRef> {
29 compute.register_kernel(kernel.0.clone());
30 }
31 compute
32});
33
34pub(crate) fn warm_up_vtable() -> usize {
35 COMPARE_FN.kernels().len()
36}
37
38pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
41 COMPARE_FN
42 .invoke(&InvocationArgs {
43 inputs: &[left.into(), right.into()],
44 options: &operator,
45 })?
46 .unwrap_array()
47}
48
/// A binary comparison operator, as used by [`compare`] and [`scalar_cmp`].
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Hash)]
pub enum Operator {
    /// Equality, displayed as `=`.
    Eq,
    /// Inequality, displayed as `!=`.
    NotEq,
    /// Strictly greater than, displayed as `>`.
    Gt,
    /// Greater than or equal, displayed as `>=`.
    Gte,
    /// Strictly less than, displayed as `<`.
    Lt,
    /// Less than or equal, displayed as `<=`.
    Lte,
}

impl Display for Operator {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        // `pad` honours width/alignment flags, exactly like formatting a `&str` does.
        f.pad(match self {
            Self::Eq => "=",
            Self::NotEq => "!=",
            Self::Gt => ">",
            Self::Gte => ">=",
            Self::Lt => "<",
            Self::Lte => "<=",
        })
    }
}

impl Operator {
    /// Returns the logical negation of this operator.
    ///
    /// For non-null operands, `!(a op b)` equals `a op.inverse() b`.
    pub fn inverse(self) -> Self {
        match self {
            Self::Eq => Self::NotEq,
            Self::NotEq => Self::Eq,
            Self::Gt => Self::Lte,
            Self::Lte => Self::Gt,
            Self::Gte => Self::Lt,
            Self::Lt => Self::Gte,
        }
    }

    /// Returns the operator to use when the operands are exchanged.
    ///
    /// `a op b` equals `b op.swap() a`. Equality and inequality are symmetric and map to
    /// themselves.
    pub fn swap(self) -> Self {
        match self {
            Self::Gt => Self::Lt,
            Self::Lt => Self::Gt,
            Self::Gte => Self::Lte,
            Self::Lte => Self::Gte,
            symmetric @ (Self::Eq | Self::NotEq) => symmetric,
        }
    }
}
103
/// A type-erased compare [`Kernel`] registered through `inventory`.
///
/// Every collected `CompareKernelRef` is registered with the `compare` compute function when
/// `COMPARE_FN` is first initialized.
pub struct CompareKernelRef(ArcRef<dyn Kernel>);
// Declares the `inventory` collection that `COMPARE_FN` iterates over at initialization.
inventory::collect!(CompareKernelRef);
106
/// Encoding-specific implementation of `compare` where the left-hand operand is `Self::Array`.
pub trait CompareKernel: VTable {
    /// Compares `lhs` against `rhs` using `operator`.
    ///
    /// Return `Ok(None)` when this encoding does not handle the given `rhs`/`operator`
    /// combination. The dispatcher then continues with the next implementation (other kernels,
    /// the swapped operands, and finally Arrow).
    fn compare(
        &self,
        lhs: &Self::Array,
        rhs: &dyn Array,
        operator: Operator,
    ) -> VortexResult<Option<ArrayRef>>;
}
115
/// Adapts an encoding's [`CompareKernel`] into a type-erased [`Kernel`].
///
/// Call `lift` on a `'static` adapter to obtain a [`CompareKernelRef`] for registration.
#[derive(Debug)]
pub struct CompareKernelAdapter<V: VTable>(pub V);
118
119impl<V: VTable + CompareKernel> CompareKernelAdapter<V> {
120 pub const fn lift(&'static self) -> CompareKernelRef {
121 CompareKernelRef(ArcRef::new_ref(self))
122 }
123}
124
125impl<V: VTable + CompareKernel> Kernel for CompareKernelAdapter<V> {
126 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
127 let inputs = CompareArgs::try_from(args)?;
128 let Some(array) = inputs.lhs.as_opt::<V>() else {
129 return Ok(None);
130 };
131 Ok(V::compare(&self.0, array, inputs.rhs, inputs.operator)?.map(|array| array.into()))
132 }
133}
134
/// Vtable for the `compare` compute function: argument validation plus dispatch fallbacks.
struct Compare;
136
137impl ComputeFnVTable for Compare {
138 fn invoke(
139 &self,
140 args: &InvocationArgs,
141 kernels: &[ArcRef<dyn Kernel>],
142 ) -> VortexResult<Output> {
143 let CompareArgs { lhs, rhs, operator } = CompareArgs::try_from(args)?;
144
145 let return_dtype = self.return_dtype(args)?;
146
147 if lhs.is_empty() {
148 return Ok(Canonical::empty(&return_dtype).into_array().into());
149 }
150
151 let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
152 let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
153 if left_constant_null || right_constant_null {
154 return Ok(ConstantArray::new(Scalar::null(return_dtype), lhs.len())
155 .into_array()
156 .into());
157 }
158
159 let right_is_constant = rhs.is_constant();
160
161 if lhs.is_constant() && !right_is_constant {
163 return Ok(compare(rhs, lhs, operator.swap())?.into());
164 }
165
166 for kernel in kernels {
168 if let Some(output) = kernel.invoke(args)? {
169 return Ok(output);
170 }
171 }
172 if let Some(output) = lhs.invoke(&COMPARE_FN, args)? {
173 return Ok(output);
174 }
175
176 let inverted_args = InvocationArgs {
178 inputs: &[rhs.into(), lhs.into()],
179 options: &operator.swap(),
180 };
181 for kernel in kernels {
182 if let Some(output) = kernel.invoke(&inverted_args)? {
183 return Ok(output);
184 }
185 }
186 if let Some(output) = rhs.invoke(&COMPARE_FN, &inverted_args)? {
187 return Ok(output);
188 }
189
190 if !(lhs.is_arrow() && (rhs.is_arrow() || right_is_constant)) {
193 log::debug!(
194 "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
195 lhs.encoding_id(),
196 rhs.encoding_id(),
197 operator,
198 );
199 }
200
201 Ok(arrow_compare(lhs, rhs, operator)?.into())
203 }
204
205 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
206 let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
207
208 if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) {
209 vortex_bail!(
210 "Cannot compare different DTypes {} and {}",
211 lhs.dtype(),
212 rhs.dtype()
213 );
214 }
215
216 Ok(DType::Bool(
217 lhs.dtype().nullability() | rhs.dtype().nullability(),
218 ))
219 }
220
221 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
222 let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
223 if lhs.len() != rhs.len() {
224 vortex_bail!(
225 "Compare operations only support arrays of the same length, got {} and {}",
226 lhs.len(),
227 rhs.len()
228 );
229 }
230 Ok(lhs.len())
231 }
232
233 fn is_elementwise(&self) -> bool {
234 true
235 }
236}
237
/// Typed view of the `compare` compute function's arguments, extracted via `TryFrom`.
struct CompareArgs<'a> {
    /// Left-hand operand.
    lhs: &'a dyn Array,
    /// Right-hand operand. Its length and dtype (ignoring nullability) are validated against
    /// `lhs` by `Compare::return_len` / `Compare::return_dtype`.
    rhs: &'a dyn Array,
    /// Comparison to apply as `lhs <operator> rhs`.
    operator: Operator,
}
243
244impl Options for Operator {
245 fn as_any(&self) -> &dyn Any {
246 self
247 }
248}
249
250impl<'a> TryFrom<&InvocationArgs<'a>> for CompareArgs<'a> {
251 type Error = VortexError;
252
253 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
254 if value.inputs.len() != 2 {
255 vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
256 }
257 let lhs = value.inputs[0]
258 .array()
259 .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
260 let rhs = value.inputs[1]
261 .array()
262 .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
263 let operator = *value
264 .options
265 .as_any()
266 .downcast_ref::<Operator>()
267 .vortex_expect("Expected options to be an operator");
268
269 Ok(CompareArgs { lhs, rhs, operator })
270 }
271}
272
273pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BitBuffer
276where
277 P: IntegerPType,
278 I: Iterator<Item = P>,
279{
280 let cmp_fn = match op {
282 Operator::Eq | Operator::Lte => |v| v == P::zero(),
283 Operator::NotEq | Operator::Gt => |v| v != P::zero(),
284 Operator::Gte => |_| true,
285 Operator::Lt => |_| false,
286 };
287
288 lengths.map(cmp_fn).collect()
289}
290
291fn arrow_compare(
293 left: &dyn Array,
294 right: &dyn Array,
295 operator: Operator,
296) -> VortexResult<ArrayRef> {
297 assert_eq!(left.len(), right.len());
298
299 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
300
301 let array = if left.dtype().is_nested() || right.dtype().is_nested() {
302 let rhs = right.to_array().into_arrow_preferred()?;
303 let lhs = left.to_array().into_arrow(rhs.data_type())?;
304
305 assert!(
306 lhs.data_type().equals_datatype(rhs.data_type()),
307 "lhs data_type: {}, rhs data_type: {}",
308 lhs.data_type(),
309 rhs.data_type()
310 );
311
312 let cmp = make_comparator(lhs.as_ref(), rhs.as_ref(), SortOptions::default())?;
313 let len = left.len();
314 let values = (0..len)
315 .map(|i| {
316 let cmp = cmp(i, i);
317 match operator {
318 Operator::Eq => cmp.is_eq(),
319 Operator::NotEq => cmp.is_ne(),
320 Operator::Gt => cmp.is_gt(),
321 Operator::Gte => cmp.is_gt() || cmp.is_eq(),
322 Operator::Lt => cmp.is_lt(),
323 Operator::Lte => cmp.is_lt() || cmp.is_eq(),
324 }
325 })
326 .collect();
327 let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
328 BooleanArray::new(values, nulls)
329 } else {
330 let lhs = Datum::try_new(left)?;
331 let rhs = Datum::try_new(right)?;
332
333 match operator {
334 Operator::Eq => cmp::eq(&lhs, &rhs)?,
335 Operator::NotEq => cmp::neq(&lhs, &rhs)?,
336 Operator::Gt => cmp::gt(&lhs, &rhs)?,
337 Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
338 Operator::Lt => cmp::lt(&lhs, &rhs)?,
339 Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
340 }
341 };
342 Ok(from_arrow_array_with_len(&array, left.len(), nullable))
343}
344
345pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
346 if lhs.is_null() | rhs.is_null() {
347 Scalar::null(DType::Bool(Nullability::Nullable))
348 } else {
349 let b = match operator {
350 Operator::Eq => lhs == rhs,
351 Operator::NotEq => lhs != rhs,
352 Operator::Gt => lhs > rhs,
353 Operator::Gte => lhs >= rhs,
354 Operator::Lt => lhs < rhs,
355 Operator::Lte => lhs <= rhs,
356 };
357
358 Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
359 }
360}
361
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldName, FieldNames};

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{
        BoolArray, ConstantArray, ListArray, ListViewArray, PrimitiveArray, StructArray,
        VarBinArray, VarBinViewArray,
    };
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    #[test]
    fn test_bool_basic_comparisons() {
        let arr = BoolArray::from_bit_buffer(
            BitBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = compare(arr.as_ref(), arr.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool();

        assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]);

        let matches = compare(arr.as_ref(), arr.as_ref(), Operator::NotEq)
            .unwrap()
            .to_bool();
        let empty: [u64; 0] = [];
        assert_eq!(to_int_indices(matches).unwrap(), empty);

        let other = BoolArray::from_bit_buffer(
            BitBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = compare(arr.as_ref(), other.as_ref(), Operator::Lte)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = compare(arr.as_ref(), other.as_ref(), Operator::Lt)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);

        let matches = compare(other.as_ref(), arr.as_ref(), Operator::Gte)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = compare(other.as_ref(), arr.as_ref(), Operator::Gt)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
    }

    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        let compare = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap();
        let res = compare.as_constant().unwrap();
        assert_eq!(res.as_bool().value(), Some(false));
        assert_eq!(compare.len(), 10);

        let compare = arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
        let res = compare.as_constant().unwrap();
        assert_eq!(res.as_bool().value(), Some(false));
        assert_eq!(compare.len(), 10);
    }

    #[test]
    fn operator_inverse_and_swap() {
        let all = [
            Operator::Eq,
            Operator::NotEq,
            Operator::Gt,
            Operator::Gte,
            Operator::Lt,
            Operator::Lte,
        ];
        for op in all {
            assert_eq!(op.inverse().inverse(), op);
            assert_eq!(op.swap().swap(), op);
        }
        assert_eq!(Operator::Gt.inverse(), Operator::Lte);
        assert_eq!(Operator::Gte.inverse(), Operator::Lt);
        assert_eq!(Operator::Gt.swap(), Operator::Lt);
        assert_eq!(Operator::NotEq.swap(), Operator::NotEq);
    }

    #[test]
    fn scalar_cmp_values_and_nulls() {
        let one = Scalar::from(1u32);
        let two = Scalar::from(2u32);
        assert_eq!(
            scalar_cmp(&one, &two, Operator::Lt).as_bool().value(),
            Some(true)
        );
        assert_eq!(
            scalar_cmp(&one, &two, Operator::Gte).as_bool().value(),
            Some(false)
        );

        let null = Scalar::null(DType::Primitive(
            vortex_dtype::PType::U32,
            Nullability::Nullable,
        ));
        let result = scalar_cmp(&one, &null, Operator::Eq);
        assert!(result.is_null());
        assert!(result.dtype().is_nullable());
    }

    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths: Vec<i32> = vec![1, 5, 7, 0];

        let output = compare_lengths_to_empty(lengths.iter().copied(), op);
        assert_eq!(Vec::from_iter(output.iter()), expected);
    }

    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let res = compare(&left, &right, Operator::Eq).unwrap();
        assert_eq!(res.to_bool().bit_buffer().true_count(), left.len());
    }

    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_comparison() {
        // list1 = [[1, 2], [3, 4], [5, 6]]
        let values1 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
        let offsets1 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list1 = ListArray::try_new(
            values1.into_array(),
            offsets1.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        // list2 = [[1, 2], [3, 4], [7, 8]]
        let values2 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 7, 8]);
        let offsets2 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list2 = ListArray::try_new(
            values2.into_array(),
            offsets2.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let result = compare(list1.as_ref(), list2.as_ref(), Operator::Eq).unwrap();
        let bool_result = result.to_bool();
        assert!(bool_result.bit_buffer().value(0)); // [1, 2] == [1, 2]
        assert!(bool_result.bit_buffer().value(1)); // [3, 4] == [3, 4]
        assert!(!bool_result.bit_buffer().value(2)); // [5, 6] != [7, 8]

        let result = compare(list1.as_ref(), list2.as_ref(), Operator::NotEq).unwrap();
        let bool_result = result.to_bool();
        assert!(!bool_result.bit_buffer().value(0));
        assert!(!bool_result.bit_buffer().value(1));
        assert!(bool_result.bit_buffer().value(2));

        let result = compare(list1.as_ref(), list2.as_ref(), Operator::Lt).unwrap();
        let bool_result = result.to_bool();
        assert!(!bool_result.bit_buffer().value(0)); // [1, 2] == [1, 2]
        assert!(!bool_result.bit_buffer().value(1)); // [3, 4] == [3, 4]
        assert!(bool_result.bit_buffer().value(2)); // [5, 6] < [7, 8]
    }

    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_constant_comparison() {
        use std::sync::Arc;

        use vortex_dtype::{DType, PType};

        // list = [[1, 2], [3, 4], [5, 6]]
        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
        let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list = ListArray::try_new(
            values.into_array(),
            offsets.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        // Constant [3, 4] repeated for every row.
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![3i32.into(), 4i32.into()],
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(list_scalar, 3);

        let result = compare(list.as_ref(), constant.as_ref(), Operator::Eq).unwrap();
        let bool_result = result.to_bool();
        assert!(!bool_result.bit_buffer().value(0)); // [1, 2] != [3, 4]
        assert!(bool_result.bit_buffer().value(1)); // [3, 4] == [3, 4]
        assert!(!bool_result.bit_buffer().value(2)); // [5, 6] != [3, 4]
    }

    #[test]
    fn test_struct_array_comparison() {
        // struct1 rows: (true, 1), (false, 2), (true, 3)
        let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
        let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]);

        // struct2 rows: (true, 1), (false, 2), (false, 4)
        let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]);
        let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]);

        let struct1 = StructArray::from_fields(&[
            ("bool_col", bool_field1.into_array()),
            ("int_col", int_field1.into_array()),
        ])
        .unwrap();

        let struct2 = StructArray::from_fields(&[
            ("bool_col", bool_field2.into_array()),
            ("int_col", int_field2.into_array()),
        ])
        .unwrap();

        let result = compare(struct1.as_ref(), struct2.as_ref(), Operator::Eq).unwrap();
        let bool_result = result.to_bool();
        assert!(bool_result.bit_buffer().value(0)); // (true, 1) == (true, 1)
        assert!(bool_result.bit_buffer().value(1)); // (false, 2) == (false, 2)
        assert!(!bool_result.bit_buffer().value(2)); // (true, 3) != (false, 4)

        let result = compare(struct1.as_ref(), struct2.as_ref(), Operator::Gt).unwrap();
        let bool_result = result.to_bool();
        assert!(!bool_result.bit_buffer().value(0)); // equal rows are not greater
        assert!(!bool_result.bit_buffer().value(1)); // equal rows are not greater
        assert!(bool_result.bit_buffer().value(2)); // first field decides: true > false
    }

    #[test]
    fn test_empty_struct_compare() {
        let empty1 = StructArray::try_new(
            FieldNames::from(Vec::<FieldName>::new()),
            Vec::new(),
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let empty2 = StructArray::try_new(
            FieldNames::from(Vec::<FieldName>::new()),
            Vec::new(),
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let result = compare(empty1.as_ref(), empty2.as_ref(), Operator::Eq).unwrap();
        let result = result.to_bool();

        for idx in 0..5 {
            assert!(result.bit_buffer().value(idx));
        }
    }

    #[test]
    fn test_empty_list() {
        let list = ListViewArray::new(
            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            Validity::AllValid,
        );

        // Three valid, empty lists compared with themselves: every row must be valid.
        let result = compare(list.as_ref(), list.as_ref(), Operator::Eq).unwrap();
        assert!(result.scalar_at(0).is_valid());
        assert!(result.scalar_at(1).is_valid());
        assert!(result.scalar_at(2).is_valid());
    }
}