1use arrow_array::BooleanArray;
5use arrow_ord::cmp;
6use vortex_error::VortexResult;
7
8use crate::Array;
9use crate::ArrayRef;
10use crate::Canonical;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::arrays::ConstantArray;
14use crate::arrays::ConstantVTable;
15use crate::arrays::ExactScalarFn;
16use crate::arrays::ScalarFnArrayView;
17use crate::arrays::ScalarFnVTable;
18use crate::arrow::Datum;
19use crate::arrow::IntoArrowArray;
20use crate::arrow::from_arrow_array_with_len;
21use crate::compute::compare_nested_arrow_arrays;
22use crate::dtype::DType;
23use crate::dtype::Nullability;
24use crate::kernel::ExecuteParentKernel;
25use crate::scalar::Scalar;
26use crate::scalar_fn::fns::binary::Binary;
27use crate::scalar_fn::fns::operators::CompareOperator;
28use crate::vtable::VTable;
29
30pub trait CompareKernel: VTable {
36 fn compare(
37 lhs: &Self::Array,
38 rhs: &dyn Array,
39 operator: CompareOperator,
40 ctx: &mut ExecutionCtx,
41 ) -> VortexResult<Option<ArrayRef>>;
42}
43
44#[derive(Default, Debug)]
50pub struct CompareExecuteAdaptor<V>(pub V);
51
52impl<V> ExecuteParentKernel<V> for CompareExecuteAdaptor<V>
53where
54 V: CompareKernel,
55{
56 type Parent = ExactScalarFn<Binary>;
57
58 fn execute_parent(
59 &self,
60 array: &V::Array,
61 parent: ScalarFnArrayView<'_, Binary>,
62 child_idx: usize,
63 ctx: &mut ExecutionCtx,
64 ) -> VortexResult<Option<ArrayRef>> {
65 let Ok(cmp_op) = CompareOperator::try_from(*parent.options) else {
67 return Ok(None);
68 };
69
70 let Some(scalar_fn_array) = parent.as_opt::<ScalarFnVTable>() else {
72 return Ok(None);
73 };
74 let children = scalar_fn_array.children();
75
76 let (cmp_op, other) = match child_idx {
79 0 => (cmp_op, &children[1]),
80 1 => (cmp_op.swap(), &children[0]),
81 _ => return Ok(None),
82 };
83
84 let len = array.len();
85 let nullable = array.dtype().is_nullable() || other.dtype().is_nullable();
86
87 if len == 0 {
89 return Ok(Some(
90 Canonical::empty(&DType::Bool(nullable.into())).into_array(),
91 ));
92 }
93
94 if other.as_constant().is_some_and(|s| s.is_null()) {
96 return Ok(Some(
97 ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), len).into_array(),
98 ));
99 }
100
101 V::compare(array, other.as_ref(), cmp_op, ctx)
102 }
103}
104
105pub(crate) fn execute_compare(
110 lhs: &dyn Array,
111 rhs: &dyn Array,
112 op: CompareOperator,
113) -> VortexResult<ArrayRef> {
114 let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
115
116 if lhs.is_empty() {
117 return Ok(Canonical::empty(&DType::Bool(nullable.into())).into_array());
118 }
119
120 let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
121 let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
122 if left_constant_null || right_constant_null {
123 return Ok(
124 ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), lhs.len()).into_array(),
125 );
126 }
127
128 if let (Some(lhs_const), Some(rhs_const)) = (
130 lhs.as_opt::<ConstantVTable>(),
131 rhs.as_opt::<ConstantVTable>(),
132 ) {
133 let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op);
134 return Ok(ConstantArray::new(result, lhs.len()).into_array());
135 }
136
137 arrow_compare_arrays(lhs, rhs, op)
138}
139
140fn arrow_compare_arrays(
142 left: &dyn Array,
143 right: &dyn Array,
144 operator: CompareOperator,
145) -> VortexResult<ArrayRef> {
146 assert_eq!(left.len(), right.len());
147
148 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
149
150 let array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
153 let rhs = right.to_array().into_arrow_preferred()?;
154 let lhs = left.to_array().into_arrow(rhs.data_type())?;
155
156 assert!(
157 lhs.data_type().equals_datatype(rhs.data_type()),
158 "lhs data_type: {}, rhs data_type: {}",
159 lhs.data_type(),
160 rhs.data_type()
161 );
162
163 compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)?
164 } else {
165 let lhs = Datum::try_new(left)?;
167 let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type())?;
168
169 match operator {
170 CompareOperator::Eq => cmp::eq(&lhs, &rhs)?,
171 CompareOperator::NotEq => cmp::neq(&lhs, &rhs)?,
172 CompareOperator::Gt => cmp::gt(&lhs, &rhs)?,
173 CompareOperator::Gte => cmp::gt_eq(&lhs, &rhs)?,
174 CompareOperator::Lt => cmp::lt(&lhs, &rhs)?,
175 CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?,
176 }
177 };
178 from_arrow_array_with_len(&array, left.len(), nullable)
179}
180
181pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> Scalar {
182 if lhs.is_null() | rhs.is_null() {
183 Scalar::null(DType::Bool(Nullability::Nullable))
184 } else {
185 let b = match operator {
186 CompareOperator::Eq => lhs == rhs,
187 CompareOperator::NotEq => lhs != rhs,
188 CompareOperator::Gt => lhs > rhs,
189 CompareOperator::Gte => lhs >= rhs,
190 CompareOperator::Lt => lhs < rhs,
191 CompareOperator::Lte => lhs <= rhs,
192 };
193
194 Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
195 }
196}
197
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::ListArray;
    use crate::arrays::ListViewArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Evaluates `lhs <op> rhs` through the `binary` builtin, panicking on error.
    fn eval(lhs: ArrayRef, rhs: ArrayRef, op: Operator) -> ArrayRef {
        lhs.binary(rhs, op).unwrap()
    }

    /// Compares two boolean arrays and canonicalises the result back to a `BoolArray`.
    fn bool_matches(lhs: &BoolArray, rhs: &BoolArray, op: Operator) -> BoolArray {
        eval(lhs.to_array(), rhs.to_array(), op).to_bool()
    }

    /// Builds a non-nullable list array of three two-element `i32` lists from six values.
    fn pairs_list(values: [i32; 6]) -> ListArray {
        ListArray::try_new(
            PrimitiveArray::from_iter(values).into_array(),
            PrimitiveArray::from_iter([0i32, 2, 4, 6]).into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Builds a struct array with a `bool_col` (from `Option<bool>` values) and an `int_col`.
    fn bool_int_struct(bools: [Option<bool>; 3], ints: [i32; 3]) -> StructArray {
        StructArray::from_fields(&[
            ("bool_col", BoolArray::from_iter(bools).into_array()),
            ("int_col", PrimitiveArray::from_iter(ints).into_array()),
        ])
        .unwrap()
    }

    #[test]
    fn test_bool_basic_comparisons() {
        let arr = BoolArray::new(
            BitBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );
        let other = BoolArray::new(
            BitBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        assert_eq!(
            to_int_indices(bool_matches(&arr, &arr, Operator::Eq)).unwrap(),
            [1u64, 2, 3, 4]
        );

        let no_matches: [u64; 0] = [];
        assert_eq!(
            to_int_indices(bool_matches(&arr, &arr, Operator::NotEq)).unwrap(),
            no_matches
        );

        assert_eq!(
            to_int_indices(bool_matches(&arr, &other, Operator::Lte)).unwrap(),
            [2u64, 3, 4]
        );
        assert_eq!(
            to_int_indices(bool_matches(&arr, &other, Operator::Lt)).unwrap(),
            [4u64]
        );
        assert_eq!(
            to_int_indices(bool_matches(&other, &arr, Operator::Gte)).unwrap(),
            [2u64, 3, 4]
        );
        assert_eq!(
            to_int_indices(bool_matches(&other, &arr, Operator::Gt)).unwrap(),
            [4u64]
        );
    }

    #[test]
    fn constant_compare() {
        let two = ConstantArray::new(Scalar::from(2u32), 10).to_array();
        let ten = ConstantArray::new(Scalar::from(10u32), 10).to_array();

        let result = eval(two, ten, Operator::Gt);
        assert_eq!(result.len(), 10);

        let first = result.scalar_at(0).unwrap();
        assert_eq!(first.as_bool().value(), Some(false));
    }

    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let res = eval(left, right, Operator::Eq);
        let expected = BoolArray::from_iter([true, true]);
        assert_arrays_eq!(res, expected);
    }

    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_comparison() {
        let list1 = pairs_list([1, 2, 3, 4, 5, 6]);
        let list2 = pairs_list([1, 2, 3, 4, 7, 8]);

        for (op, expected_bits) in [
            (Operator::Eq, [true, true, false]),
            (Operator::NotEq, [false, false, true]),
            (Operator::Lt, [false, false, true]),
        ] {
            let result = eval(list1.to_array(), list2.to_array(), op);
            let expected = BoolArray::from_iter(expected_bits);
            assert_arrays_eq!(result, expected);
        }
    }

    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_constant_comparison() {
        let list = pairs_list([1, 2, 3, 4, 5, 6]);

        let needle = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![3i32.into(), 4i32.into()],
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(needle, 3);

        let result = eval(list.to_array(), constant.to_array(), Operator::Eq);
        let expected = BoolArray::from_iter([false, true, false]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_struct_array_comparison() {
        let struct1 = bool_int_struct([Some(true), Some(false), Some(true)], [1, 2, 3]);
        let struct2 = bool_int_struct([Some(true), Some(false), Some(false)], [1, 2, 4]);

        for (op, expected_bits) in [
            (Operator::Eq, [true, true, false]),
            (Operator::Gt, [false, false, true]),
        ] {
            let result = eval(struct1.to_array(), struct2.to_array(), op);
            let expected = BoolArray::from_iter(expected_bits);
            assert_arrays_eq!(result, expected);
        }
    }

    #[test]
    fn test_empty_struct_compare() {
        // A struct with no fields but five rows.
        let fieldless = || {
            StructArray::try_new(
                FieldNames::from(Vec::<FieldName>::new()),
                Vec::new(),
                5,
                Validity::NonNullable,
            )
            .unwrap()
        };
        let empty1 = fieldless();
        let empty2 = fieldless();

        let result = eval(empty1.to_array(), empty2.to_array(), Operator::Eq);
        let expected = BoolArray::from_iter([true, true, true, true, true]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_empty_list() {
        // Three rows, each an empty list over an empty bool child.
        let list = ListViewArray::new(
            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            Validity::AllValid,
        );

        let result = eval(list.to_array(), list.to_array(), Operator::Eq);
        for idx in 0..3 {
            assert!(result.scalar_at(idx).unwrap().is_valid());
        }
    }
}