1use core::fmt;
5use std::cmp::Ordering;
6use std::fmt::Display;
7use std::fmt::Formatter;
8
9use arrow_array::BooleanArray;
10use arrow_buffer::NullBuffer;
11use arrow_ord::ord::make_comparator;
12use arrow_schema::SortOptions;
13use vortex_buffer::BitBuffer;
14use vortex_dtype::DType;
15use vortex_dtype::IntegerPType;
16use vortex_dtype::Nullability;
17use vortex_error::VortexResult;
18
19use crate::Array;
20use crate::ArrayRef;
21use crate::IntoArray;
22use crate::arrays::ScalarFnArray;
23use crate::expr::Binary;
24use crate::expr::ScalarFn;
25use crate::expr::operators;
26use crate::scalar::Scalar;
27
28pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
32 let expr_op: operators::Operator = operator.into();
33 Ok(ScalarFnArray::try_new(
34 ScalarFn::new(Binary, expr_op),
35 vec![left.to_array(), right.to_array()],
36 left.len(),
37 )?
38 .into_array())
39}
40
/// A binary comparison operator.
///
/// The declaration order of the variants is significant: it determines the derived
/// [`PartialOrd`] ordering (`Eq < NotEq < Gt < Gte < Lt < Lte`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Hash)]
pub enum Operator {
    /// Equal to (`=`).
    Eq,
    /// Not equal to (`!=`).
    NotEq,
    /// Greater than (`>`).
    Gt,
    /// Greater than or equal to (`>=`).
    Gte,
    /// Less than (`<`).
    Lt,
    /// Less than or equal to (`<=`).
    Lte,
}

impl Display for Operator {
    /// Writes the operator's symbol (`=`, `!=`, `>`, `>=`, `<`, `<=`).
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let symbol = match self {
            Self::Eq => "=",
            Self::NotEq => "!=",
            Self::Gt => ">",
            Self::Gte => ">=",
            Self::Lt => "<",
            Self::Lte => "<=",
        };
        // `pad` is what `<str as Display>::fmt` delegates to, so width/fill/alignment flags
        // supplied by the caller are still honoured.
        f.pad(symbol)
    }
}

impl Operator {
    /// Returns the logical negation of this operator.
    ///
    /// For example `Gt` becomes `Lte` and `Eq` becomes `NotEq`. Applying `inverse` twice
    /// yields the original operator.
    pub fn inverse(self) -> Self {
        match self {
            Self::Eq => Self::NotEq,
            Self::NotEq => Self::Eq,
            Self::Gt => Self::Lte,
            Self::Gte => Self::Lt,
            Self::Lt => Self::Gte,
            Self::Lte => Self::Gt,
        }
    }

    /// Returns the operator to use when the two operands are swapped.
    ///
    /// `a > b` is the same as `b < a`, so `Gt` maps to `Lt` (and likewise for the other
    /// ordering operators). The symmetric operators `Eq` and `NotEq` are returned unchanged.
    pub fn swap(self) -> Self {
        match self {
            Self::Eq | Self::NotEq => self,
            Self::Gt => Self::Lt,
            Self::Gte => Self::Lte,
            Self::Lt => Self::Gt,
            Self::Lte => Self::Gte,
        }
    }
}
95
96pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BitBuffer
99where
100 P: IntegerPType,
101 I: Iterator<Item = P>,
102{
103 let cmp_fn = match op {
105 Operator::Eq | Operator::Lte => |v| v == P::zero(),
106 Operator::NotEq | Operator::Gt => |v| v != P::zero(),
107 Operator::Gte => |_| true,
108 Operator::Lt => |_| false,
109 };
110
111 lengths.map(cmp_fn).collect()
112}
113
114pub(crate) fn compare_nested_arrow_arrays(
122 lhs: &dyn arrow_array::Array,
123 rhs: &dyn arrow_array::Array,
124 operator: Operator,
125) -> VortexResult<BooleanArray> {
126 let compare_arrays_at = make_comparator(lhs, rhs, SortOptions::default())?;
127
128 let cmp_fn = match operator {
129 Operator::Eq => Ordering::is_eq,
130 Operator::NotEq => Ordering::is_ne,
131 Operator::Gt => Ordering::is_gt,
132 Operator::Gte => Ordering::is_ge,
133 Operator::Lt => Ordering::is_lt,
134 Operator::Lte => Ordering::is_le,
135 };
136
137 let values = (0..lhs.len())
138 .map(|i| cmp_fn(compare_arrays_at(i, i)))
139 .collect();
140 let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
141
142 Ok(BooleanArray::new(values, nulls))
143}
144
145pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
146 if lhs.is_null() | rhs.is_null() {
147 Scalar::null(DType::Bool(Nullability::Nullable))
148 } else {
149 let b = match operator {
150 Operator::Eq => lhs == rhs,
151 Operator::NotEq => lhs != rhs,
152 Operator::Gt => lhs > rhs,
153 Operator::Gte => lhs >= rhs,
154 Operator::Lt => lhs < rhs,
155 Operator::Lte => lhs <= rhs,
156 };
157
158 Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
159 }
160}
161
// Unit tests for `compare`, `compare_lengths_to_empty` and the Arrow-backed comparison paths.
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::FieldName;
    use vortex_dtype::FieldNames;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::ListArray;
    use crate::arrays::ListViewArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    // Exercises every operator on nullable boolean arrays. Index 0 is null in both inputs, and
    // the expected index lists below never contain 0, so null positions never count as matches.
    #[test]
    fn test_bool_basic_comparisons() {
        let arr = BoolArray::new(
            BitBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        // Comparing an array with itself: every non-null position is equal.
        let matches = compare(arr.as_ref(), arr.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool();

        assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]);

        // ...and no non-null position is unequal.
        let matches = compare(arr.as_ref(), arr.as_ref(), Operator::NotEq)
            .unwrap()
            .to_bool();
        let empty: [u64; 0] = [];
        assert_eq!(to_int_indices(matches).unwrap(), empty);

        let other = BoolArray::new(
            BitBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        // arr vs other (non-null part): [T,F,T,F] vs [F,F,T,T] at indices 1..=4, with false < true.
        let matches = compare(arr.as_ref(), other.as_ref(), Operator::Lte)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = compare(arr.as_ref(), other.as_ref(), Operator::Lt)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);

        // Swapped operands with mirrored operators must give the same matches as above.
        let matches = compare(other.as_ref(), arr.as_ref(), Operator::Gte)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = compare(other.as_ref(), arr.as_ref(), Operator::Gt)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
    }

    // Two constant arrays: 2 > 10 is false, and the result keeps the input length.
    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        let result = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap();
        assert_eq!(result.len(), 10);
        let scalar = result.scalar_at(0).unwrap();
        assert_eq!(scalar.as_bool().value(), Some(false));
    }

    // `length <op> 0` for lengths [1, 5, 7, 0]; only the last length is zero.
    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths: Vec<i32> = vec![1, 5, 7, 0];

        let output = compare_lengths_to_empty(lengths.iter().copied(), op);
        assert_eq!(Vec::from_iter(output.iter()), expected);
    }

    // Equal string/binary data stored in different encodings (VarBin vs VarBinView), in both
    // operand orders, must compare equal element-wise.
    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let res = compare(&left, &right, Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([true, true]);
        assert_arrays_eq!(res, expected);
    }

    // Lists [[1,2],[3,4],[5,6]] vs [[1,2],[3,4],[7,8]]: only the last row differs, and
    // [5,6] < [7,8]. Ignored for the reason given in the attribute.
    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_comparison() {
        let values1 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
        let offsets1 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list1 = ListArray::try_new(
            values1.into_array(),
            offsets1.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let values2 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 7, 8]);
        let offsets2 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list2 = ListArray::try_new(
            values2.into_array(),
            offsets2.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let result = compare(list1.as_ref(), list2.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([true, true, false]);
        assert_arrays_eq!(result, expected);

        let result = compare(list1.as_ref(), list2.as_ref(), Operator::NotEq).unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected);

        let result = compare(list1.as_ref(), list2.as_ref(), Operator::Lt).unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected);
    }

    // A list array compared with a constant list scalar [3, 4]; only the middle row matches.
    // Ignored for the reason given in the attribute.
    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_constant_comparison() {
        use std::sync::Arc;

        use vortex_dtype::DType;
        use vortex_dtype::PType;

        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
        let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list = ListArray::try_new(
            values.into_array(),
            offsets.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![3i32.into(), 4i32.into()],
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(list_scalar, 3);

        let result = compare(list.as_ref(), constant.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([false, true, false]);
        assert_arrays_eq!(result, expected);
    }

    // Struct rows: (true,1),(false,2),(true,3) vs (true,1),(false,2),(false,4). Rows 0 and 1
    // are equal; in row 2 the first field already decides the order (true > false).
    #[test]
    fn test_struct_array_comparison() {
        let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
        let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]);

        let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]);
        let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]);

        let struct1 = StructArray::from_fields(&[
            ("bool_col", bool_field1.into_array()),
            ("int_col", int_field1.into_array()),
        ])
        .unwrap();

        let struct2 = StructArray::from_fields(&[
            ("bool_col", bool_field2.into_array()),
            ("int_col", int_field2.into_array()),
        ])
        .unwrap();

        let result = compare(struct1.as_ref(), struct2.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([true, true, false]);
        assert_arrays_eq!(result, expected);

        let result = compare(struct1.as_ref(), struct2.as_ref(), Operator::Gt).unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected);
    }

    // Structs with no fields: every row is expected to compare equal.
    #[test]
    fn test_empty_struct_compare() {
        let empty1 = StructArray::try_new(
            FieldNames::from(Vec::<FieldName>::new()),
            Vec::new(),
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let empty2 = StructArray::try_new(
            FieldNames::from(Vec::<FieldName>::new()),
            Vec::new(),
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let result = compare(empty1.as_ref(), empty2.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([true, true, true, true, true]);
        assert_arrays_eq!(result, expected);
    }

    // Three empty list-view entries over an empty element array (the two zero buffers are
    // presumably offsets and sizes — confirm against `ListViewArray::new`). Only checks that
    // the comparison succeeds and yields non-null results.
    #[test]
    fn test_empty_list() {
        let list = ListViewArray::new(
            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            Validity::AllValid,
        );

        let result = compare(list.as_ref(), list.as_ref(), Operator::Eq).unwrap();
        assert!(result.scalar_at(0).unwrap().is_valid());
        assert!(result.scalar_at(1).unwrap().is_valid());
        assert!(result.scalar_at(2).unwrap().is_valid());
    }
}