1use core::fmt;
5use std::any::Any;
6use std::fmt::{Display, Formatter};
7use std::sync::LazyLock;
8
9use arcref::ArcRef;
10use arrow_array::{BooleanArray, Datum as ArrowDatum};
11use arrow_buffer::{BooleanBuffer, NullBuffer};
12use arrow_ord::cmp;
13use arrow_ord::ord::make_comparator;
14use arrow_schema::SortOptions;
15use vortex_dtype::{DType, IntegerPType, Nullability};
16use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
17use vortex_scalar::Scalar;
18
19use crate::arrays::ConstantArray;
20use crate::arrow::{Datum, from_arrow_array_with_len};
21use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
22use crate::vtable::VTable;
23use crate::{Array, ArrayRef, Canonical, IntoArray};
24
25static COMPARE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
26 let compute = ComputeFn::new("compare".into(), ArcRef::new_ref(&Compare));
27 for kernel in inventory::iter::<CompareKernelRef> {
28 compute.register_kernel(kernel.0.clone());
29 }
30 compute
31});
32
33pub(crate) fn warm_up_vtable() -> usize {
34 COMPARE_FN.kernels().len()
35}
36
37pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
40 COMPARE_FN
41 .invoke(&InvocationArgs {
42 inputs: &[left.into(), right.into()],
43 options: &operator,
44 })?
45 .unwrap_array()
46}
47
/// Binary comparison operator applied by the `compare` compute function.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Hash)]
pub enum Operator {
    /// `lhs = rhs`
    Eq,
    /// `lhs != rhs`
    NotEq,
    /// `lhs > rhs`
    Gt,
    /// `lhs >= rhs`
    Gte,
    /// `lhs < rhs`
    Lt,
    /// `lhs <= rhs`
    Lte,
}

impl Display for Operator {
    /// Writes the operator's symbol (`=`, `!=`, `>`, `>=`, `<`, `<=`), honoring any
    /// width/fill/alignment flags set on the formatter.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let symbol = match *self {
            Self::Eq => "=",
            Self::NotEq => "!=",
            Self::Gt => ">",
            Self::Gte => ">=",
            Self::Lt => "<",
            Self::Lte => "<=",
        };
        f.pad(symbol)
    }
}

impl Operator {
    /// Returns the logical negation of this operator, i.e. the operator `op'` such that
    /// `a op' b` holds exactly when `a op b` does not.
    pub fn inverse(self) -> Self {
        match self {
            Self::Eq => Self::NotEq,
            Self::NotEq => Self::Eq,
            Self::Gt => Self::Lte,
            Self::Gte => Self::Lt,
            Self::Lt => Self::Gte,
            Self::Lte => Self::Gt,
        }
    }

    /// Returns the operator to use when the operands change sides, i.e. the operator `op'`
    /// such that `b op' a` is equivalent to `a op b`.
    pub fn swap(self) -> Self {
        match self {
            // Equality and inequality are symmetric.
            Self::Eq => Self::Eq,
            Self::NotEq => Self::NotEq,
            Self::Gt => Self::Lt,
            Self::Gte => Self::Lte,
            Self::Lt => Self::Gt,
            Self::Lte => Self::Gte,
        }
    }
}
102
/// Type-erased handle to a compare kernel.
///
/// Values of this type are collected through `inventory`; the `COMPARE_FN` initializer
/// iterates them and registers the wrapped kernel on the compute function.
pub struct CompareKernelRef(ArcRef<dyn Kernel>);
// Declare `CompareKernelRef` as an `inventory` collection so it can be iterated with
// `inventory::iter::<CompareKernelRef>`.
inventory::collect!(CompareKernelRef);
105
/// Encoding-specific implementation of the compare operation.
pub trait CompareKernel: VTable {
    /// Compares `lhs` (an array of this encoding) against `rhs` using `operator`.
    ///
    /// `Ok(None)` is forwarded unchanged by [`CompareKernelAdapter`] as "no output", which
    /// lets the dispatcher move on to the next implementation; `Ok(Some(array))` is used as
    /// the comparison result.
    fn compare(
        &self,
        lhs: &Self::Array,
        rhs: &dyn Array,
        operator: Operator,
    ) -> VortexResult<Option<ArrayRef>>;
}
114
115#[derive(Debug)]
116pub struct CompareKernelAdapter<V: VTable>(pub V);
117
118impl<V: VTable + CompareKernel> CompareKernelAdapter<V> {
119 pub const fn lift(&'static self) -> CompareKernelRef {
120 CompareKernelRef(ArcRef::new_ref(self))
121 }
122}
123
124impl<V: VTable + CompareKernel> Kernel for CompareKernelAdapter<V> {
125 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
126 let inputs = CompareArgs::try_from(args)?;
127 let Some(array) = inputs.lhs.as_opt::<V>() else {
128 return Ok(None);
129 };
130 Ok(V::compare(&self.0, array, inputs.rhs, inputs.operator)?.map(|array| array.into()))
131 }
132}
133
134struct Compare;
135
136impl ComputeFnVTable for Compare {
137 fn invoke(
138 &self,
139 args: &InvocationArgs,
140 kernels: &[ArcRef<dyn Kernel>],
141 ) -> VortexResult<Output> {
142 let CompareArgs { lhs, rhs, operator } = CompareArgs::try_from(args)?;
143
144 let return_dtype = self.return_dtype(args)?;
145
146 if lhs.is_empty() {
147 return Ok(Canonical::empty(&return_dtype).into_array().into());
148 }
149
150 let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
151 let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
152 if left_constant_null || right_constant_null {
153 return Ok(ConstantArray::new(Scalar::null(return_dtype), lhs.len())
154 .into_array()
155 .into());
156 }
157
158 let right_is_constant = rhs.is_constant();
159
160 if lhs.is_constant() && !right_is_constant {
162 return Ok(compare(rhs, lhs, operator.swap())?.into());
163 }
164
165 for kernel in kernels {
167 if let Some(output) = kernel.invoke(args)? {
168 return Ok(output);
169 }
170 }
171 if let Some(output) = lhs.invoke(&COMPARE_FN, args)? {
172 return Ok(output);
173 }
174
175 let inverted_args = InvocationArgs {
177 inputs: &[rhs.into(), lhs.into()],
178 options: &operator.swap(),
179 };
180 for kernel in kernels {
181 if let Some(output) = kernel.invoke(&inverted_args)? {
182 return Ok(output);
183 }
184 }
185 if let Some(output) = rhs.invoke(&COMPARE_FN, &inverted_args)? {
186 return Ok(output);
187 }
188
189 if !(lhs.is_arrow() && (rhs.is_arrow() || right_is_constant)) {
192 log::debug!(
193 "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
194 lhs.encoding_id(),
195 rhs.encoding_id(),
196 operator,
197 );
198 }
199
200 Ok(arrow_compare(lhs, rhs, operator)?.into())
202 }
203
204 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
205 let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
206
207 if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) {
208 vortex_bail!(
209 "Cannot compare different DTypes {} and {}",
210 lhs.dtype(),
211 rhs.dtype()
212 );
213 }
214
215 Ok(DType::Bool(
216 lhs.dtype().nullability() | rhs.dtype().nullability(),
217 ))
218 }
219
220 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
221 let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
222 if lhs.len() != rhs.len() {
223 vortex_bail!(
224 "Compare operations only support arrays of the same length, got {} and {}",
225 lhs.len(),
226 rhs.len()
227 );
228 }
229 Ok(lhs.len())
230 }
231
232 fn is_elementwise(&self) -> bool {
233 true
234 }
235}
236
237struct CompareArgs<'a> {
238 lhs: &'a dyn Array,
239 rhs: &'a dyn Array,
240 operator: Operator,
241}
242
243impl Options for Operator {
244 fn as_any(&self) -> &dyn Any {
245 self
246 }
247}
248
249impl<'a> TryFrom<&InvocationArgs<'a>> for CompareArgs<'a> {
250 type Error = VortexError;
251
252 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
253 if value.inputs.len() != 2 {
254 vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
255 }
256 let lhs = value.inputs[0]
257 .array()
258 .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
259 let rhs = value.inputs[1]
260 .array()
261 .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
262 let operator = *value
263 .options
264 .as_any()
265 .downcast_ref::<Operator>()
266 .vortex_expect("Expected options to be an operator");
267
268 Ok(CompareArgs { lhs, rhs, operator })
269 }
270}
271
272pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BooleanBuffer
275where
276 P: IntegerPType,
277 I: Iterator<Item = P>,
278{
279 let cmp_fn = match op {
281 Operator::Eq | Operator::Lte => |v| v == P::zero(),
282 Operator::NotEq | Operator::Gt => |v| v != P::zero(),
283 Operator::Gte => |_| true,
284 Operator::Lt => |_| false,
285 };
286
287 lengths.map(cmp_fn).collect::<BooleanBuffer>()
288}
289
290fn arrow_compare(
292 left: &dyn Array,
293 right: &dyn Array,
294 operator: Operator,
295) -> VortexResult<ArrayRef> {
296 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
297
298 let array = if left.dtype().is_nested() || right.dtype().is_nested() {
299 let lhs = Datum::try_new_array(&left.to_canonical().into_array())?;
300 let (lhs, _) = lhs.get();
301 let rhs = Datum::try_new_array(&right.to_canonical().into_array())?;
302 let (rhs, _) = rhs.get();
303
304 let cmp = make_comparator(lhs, rhs, SortOptions::default())?;
305 assert_eq!(lhs.len(), rhs.len());
306 let len = lhs.len();
307 let values = (0..len)
308 .map(|i| {
309 let cmp = cmp(i, i);
310 match operator {
311 Operator::Eq => cmp.is_eq(),
312 Operator::NotEq => cmp.is_ne(),
313 Operator::Gt => cmp.is_gt(),
314 Operator::Gte => cmp.is_gt() || cmp.is_eq(),
315 Operator::Lt => cmp.is_lt(),
316 Operator::Lte => cmp.is_lt() || cmp.is_eq(),
317 }
318 })
319 .collect();
320 let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
321 BooleanArray::new(values, nulls)
322 } else {
323 let lhs = Datum::try_new(left)?;
324 let rhs = Datum::try_new(right)?;
325
326 match operator {
327 Operator::Eq => cmp::eq(&lhs, &rhs)?,
328 Operator::NotEq => cmp::neq(&lhs, &rhs)?,
329 Operator::Gt => cmp::gt(&lhs, &rhs)?,
330 Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
331 Operator::Lt => cmp::lt(&lhs, &rhs)?,
332 Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
333 }
334 };
335 Ok(from_arrow_array_with_len(&array, left.len(), nullable))
336}
337
338pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
339 if lhs.is_null() | rhs.is_null() {
340 Scalar::null(DType::Bool(Nullability::Nullable))
341 } else {
342 let b = match operator {
343 Operator::Eq => lhs == rhs,
344 Operator::NotEq => lhs != rhs,
345 Operator::Gt => lhs > rhs,
346 Operator::Gte => lhs >= rhs,
347 Operator::Lt => lhs < rhs,
348 Operator::Lte => lhs <= rhs,
349 };
350
351 Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
352 }
353}
354
#[cfg(test)]
mod tests {
    use arrow_buffer::BooleanBuffer;
    use rstest::rstest;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{
        BoolArray, ConstantArray, ListArray, PrimitiveArray, StructArray, VarBinArray,
        VarBinViewArray,
    };
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Asserts that row `i` of the boolean `result` equals `expected[i]` for every `i`.
    fn assert_rows(result: &ArrayRef, expected: &[bool]) {
        let bools = result.to_bool();
        for (row, &want) in expected.iter().enumerate() {
            assert_eq!(bools.boolean_buffer().value(row), want);
        }
    }

    /// Builds a non-nullable list array of three two-element lists from six values.
    fn pairs_list(values: [i32; 6]) -> ListArray {
        let elements = PrimitiveArray::from_iter(values);
        let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        ListArray::try_new(
            elements.into_array(),
            offsets.into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    #[test]
    fn test_bool_basic_comparisons() {
        // Indices of the rows for which the comparison holds.
        let selected = |lhs: &BoolArray, rhs: &BoolArray, op: Operator| {
            let mask = compare(lhs.as_ref(), rhs.as_ref(), op)
                .unwrap()
                .to_bool();
            to_int_indices(mask).unwrap()
        };

        let arr = BoolArray::from_bool_buffer(
            BooleanBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        assert_eq!(selected(&arr, &arr, Operator::Eq), [1u64, 2, 3, 4]);

        let none: [u64; 0] = [];
        assert_eq!(selected(&arr, &arr, Operator::NotEq), none);

        let other = BoolArray::from_bool_buffer(
            BooleanBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        assert_eq!(selected(&arr, &other, Operator::Lte), [2u64, 3, 4]);
        assert_eq!(selected(&arr, &other, Operator::Lt), [4u64]);
        assert_eq!(selected(&other, &arr, Operator::Gte), [2u64, 3, 4]);
        assert_eq!(selected(&other, &arr, Operator::Gt), [4u64]);
    }

    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        // Through the compute function.
        let via_dispatch = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap();
        let scalar = via_dispatch.as_constant().unwrap();
        assert_eq!(scalar.as_bool().value(), Some(false));
        assert_eq!(via_dispatch.len(), 10);

        // Directly through the Arrow fallback.
        let via_arrow =
            arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
        let scalar = via_arrow.as_constant().unwrap();
        assert_eq!(scalar.as_bool().value(), Some(false));
        assert_eq!(via_arrow.len(), 10);
    }

    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths: Vec<i32> = vec![1, 5, 7, 0];

        let output = compare_lengths_to_empty(lengths.iter().copied(), op);
        let actual: Vec<bool> = output.iter().collect();
        assert_eq!(actual, expected);
    }

    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let res = compare(&left, &right, Operator::Eq).unwrap();
        let equal_rows = res.to_bool().boolean_buffer().count_set_bits();
        assert_eq!(equal_rows, left.len());
    }

    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_comparison() {
        // [[1, 2], [3, 4], [5, 6]] against [[1, 2], [3, 4], [7, 8]]
        let list1 = pairs_list([1, 2, 3, 4, 5, 6]);
        let list2 = pairs_list([1, 2, 3, 4, 7, 8]);

        let result = compare(list1.as_ref(), list2.as_ref(), Operator::Eq).unwrap();
        assert_rows(&result, &[true, true, false]);

        let result = compare(list1.as_ref(), list2.as_ref(), Operator::NotEq).unwrap();
        assert_rows(&result, &[false, false, true]);

        let result = compare(list1.as_ref(), list2.as_ref(), Operator::Lt).unwrap();
        assert_rows(&result, &[false, false, true]);
    }

    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_constant_comparison() {
        use std::sync::Arc;

        use vortex_dtype::{DType, PType};

        // [[1, 2], [3, 4], [5, 6]]
        let list = pairs_list([1, 2, 3, 4, 5, 6]);

        // Constant list [3, 4], repeated three times.
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![3i32.into(), 4i32.into()],
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(list_scalar, 3);

        let result = compare(list.as_ref(), constant.as_ref(), Operator::Eq).unwrap();
        assert_rows(&result, &[false, true, false]);
    }

    #[test]
    fn test_struct_array_comparison() {
        let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
        let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]);

        let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]);
        let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]);

        let struct1 = StructArray::from_fields(&[
            ("bool_col", bool_field1.into_array()),
            ("int_col", int_field1.into_array()),
        ])
        .unwrap();

        let struct2 = StructArray::from_fields(&[
            ("bool_col", bool_field2.into_array()),
            ("int_col", int_field2.into_array()),
        ])
        .unwrap();

        // Rows 0 and 1 are identical; row 2 differs in both fields.
        let result = compare(struct1.as_ref(), struct2.as_ref(), Operator::Eq).unwrap();
        assert_rows(&result, &[true, true, false]);

        let result = compare(struct1.as_ref(), struct2.as_ref(), Operator::Gt).unwrap();
        assert_rows(&result, &[false, false, true]);
    }
}