// Source path: vortex_array/compute/compare.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use core::fmt;
use std::cmp::Ordering;
use std::fmt::Display;
use std::fmt::Formatter;

use arrow_array::BooleanArray;
use arrow_buffer::NullBuffer;
use arrow_ord::ord::make_comparator;
use arrow_schema::SortOptions;
use vortex_buffer::BitBuffer;
use vortex_dtype::DType;
use vortex_dtype::IntegerPType;
use vortex_dtype::Nullability;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;

use crate::Array;
use crate::ArrayRef;
use crate::IntoArray;
use crate::arrays::ScalarFnArray;
use crate::expr::Binary;
use crate::expr::ScalarFn;
use crate::expr::operators;
use crate::scalar::Scalar;

28/// Compares two arrays and returns a new boolean array with the result of the comparison.
29///
30/// The returned array is lazy (a [`ScalarFnArray`]) and will be evaluated on demand.
31pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
32    let expr_op: operators::Operator = operator.into();
33    Ok(ScalarFnArray::try_new(
34        ScalarFn::new(Binary, expr_op),
35        vec![left.to_array(), right.to_array()],
36        left.len(),
37    )?
38    .into_array())
39}
40
/// A binary comparison operator.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Hash)]
pub enum Operator {
    /// Equality (`=`)
    Eq,
    /// Inequality (`!=`)
    NotEq,
    /// Greater than (`>`)
    Gt,
    /// Greater than or equal (`>=`)
    Gte,
    /// Less than (`<`)
    Lt,
    /// Less than or equal (`<=`)
    Lte,
}

impl Display for Operator {
    /// Writes the operator's symbol: `=`, `!=`, `>`, `>=`, `<` or `<=`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let symbol = match self {
            Operator::Eq => "=",
            Operator::NotEq => "!=",
            Operator::Gt => ">",
            Operator::Gte => ">=",
            Operator::Lt => "<",
            Operator::Lte => "<=",
        };
        // Go through `str`'s `Display` impl so width and alignment flags are honored.
        Display::fmt(symbol, f)
    }
}

impl Operator {
    /// Returns the complementary operator: `Eq` <-> `NotEq`, `Gt` <-> `Lte`, `Gte` <-> `Lt`.
    ///
    /// Applying `inverse` twice yields the original operator.
    #[must_use]
    pub fn inverse(self) -> Self {
        match self {
            Operator::Eq => Operator::NotEq,
            Operator::NotEq => Operator::Eq,
            Operator::Gt => Operator::Lte,
            Operator::Gte => Operator::Lt,
            Operator::Lt => Operator::Gte,
            Operator::Lte => Operator::Gt,
        }
    }

    /// Returns the operator to use once the two operands have traded places, so that
    /// `a op b` and `b op.swap() a` give the same result.
    ///
    /// `Eq` and `NotEq` are symmetric and map to themselves; the ordering operators are mirrored
    /// (`Gt` <-> `Lt`, `Gte` <-> `Lte`).
    #[must_use]
    pub fn swap(self) -> Self {
        match self {
            Operator::Eq => Operator::Eq,
            Operator::NotEq => Operator::NotEq,
            Operator::Gt => Operator::Lt,
            Operator::Gte => Operator::Lte,
            Operator::Lt => Operator::Gt,
            Operator::Lte => Operator::Gte,
        }
    }
}

96/// Helper function to compare empty values with arrays that have external value length information
97/// like `VarBin`.
98pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BitBuffer
99where
100    P: IntegerPType,
101    I: Iterator<Item = P>,
102{
103    // All comparison can be expressed in terms of equality. "" is the absolute min of possible value.
104    let cmp_fn = match op {
105        Operator::Eq | Operator::Lte => |v| v == P::zero(),
106        Operator::NotEq | Operator::Gt => |v| v != P::zero(),
107        Operator::Gte => |_| true,
108        Operator::Lt => |_| false,
109    };
110
111    lengths.map(cmp_fn).collect()
112}
113
114/// Compare two Arrow arrays element-wise using [`make_comparator`].
115///
116/// This function is required for nested types (Struct, List, FixedSizeList) because Arrow's
117/// vectorized comparison kernels ([`cmp::eq`], [`cmp::neq`], etc.) do not support them.
118///
119/// The vectorized kernels are faster but only work on primitive types, so for non-nested types,
120/// prefer using the vectorized kernels directly for better performance.
121pub(crate) fn compare_nested_arrow_arrays(
122    lhs: &dyn arrow_array::Array,
123    rhs: &dyn arrow_array::Array,
124    operator: Operator,
125) -> VortexResult<BooleanArray> {
126    let compare_arrays_at = make_comparator(lhs, rhs, SortOptions::default())?;
127
128    let cmp_fn = match operator {
129        Operator::Eq => Ordering::is_eq,
130        Operator::NotEq => Ordering::is_ne,
131        Operator::Gt => Ordering::is_gt,
132        Operator::Gte => Ordering::is_ge,
133        Operator::Lt => Ordering::is_lt,
134        Operator::Lte => Ordering::is_le,
135    };
136
137    let values = (0..lhs.len())
138        .map(|i| cmp_fn(compare_arrays_at(i, i)))
139        .collect();
140    let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
141
142    Ok(BooleanArray::new(values, nulls))
143}
144
145pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
146    if lhs.is_null() | rhs.is_null() {
147        Scalar::null(DType::Bool(Nullability::Nullable))
148    } else {
149        let b = match operator {
150            Operator::Eq => lhs == rhs,
151            Operator::NotEq => lhs != rhs,
152            Operator::Gt => lhs > rhs,
153            Operator::Gte => lhs >= rhs,
154            Operator::Lt => lhs < rhs,
155            Operator::Lte => lhs <= rhs,
156        };
157
158        Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
159    }
160}
161
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::FieldName;
    use vortex_dtype::FieldNames;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::ListArray;
    use crate::arrays::ListViewArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Exercises every operator on nullable boolean arrays. Index 0 is invalid in both inputs
    /// and is never expected among the matching indices.
    #[test]
    fn test_bool_basic_comparisons() {
        let arr = BoolArray::new(
            BitBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        // An array compared with itself: every valid position is equal...
        let matches = compare(arr.as_ref(), arr.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool();

        assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]);

        // ...and no position is unequal.
        let matches = compare(arr.as_ref(), arr.as_ref(), Operator::NotEq)
            .unwrap()
            .to_bool();
        let empty: [u64; 0] = [];
        assert_eq!(to_int_indices(matches).unwrap(), empty);

        let other = BoolArray::new(
            BitBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = compare(arr.as_ref(), other.as_ref(), Operator::Lte)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = compare(arr.as_ref(), other.as_ref(), Operator::Lt)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);

        // Same comparisons with the operands (and the operators) mirrored.
        let matches = compare(other.as_ref(), arr.as_ref(), Operator::Gte)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = compare(other.as_ref(), arr.as_ref(), Operator::Gt)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
    }

    /// Comparing two constant arrays keeps the length and yields the scalar comparison result.
    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        let result = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap();
        assert_eq!(result.len(), 10);
        let scalar = result.scalar_at(0).unwrap();
        assert_eq!(scalar.as_bool().value(), Some(false));
    }

    /// Checks `compare_lengths_to_empty` for every operator against lengths `[1, 5, 7, 0]`;
    /// only the last length describes an empty value.
    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths: Vec<i32> = vec![1, 5, 7, 0];

        let output = compare_lengths_to_empty(lengths.iter().copied(), op);
        assert_eq!(Vec::from_iter(output.iter()), expected);
    }

    /// Equal string / binary data must compare equal even when the two sides use different
    /// encodings (`VarBinArray` vs `VarBinViewArray`), in either operand order.
    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let res = compare(&left, &right, Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([true, true]);
        assert_arrays_eq!(res, expected);
    }

    /// Element-wise comparison of two list arrays (currently ignored, see the `ignore` reason).
    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_comparison() {
        // Create two simple list arrays with integers
        let values1 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
        let offsets1 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list1 = ListArray::try_new(
            values1.into_array(),
            offsets1.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let values2 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 7, 8]);
        let offsets2 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list2 = ListArray::try_new(
            values2.into_array(),
            offsets2.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        // Test equality - first two lists should be equal, third should be different
        let result = compare(list1.as_ref(), list2.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([true, true, false]);
        assert_arrays_eq!(result, expected);

        // Test inequality
        let result = compare(list1.as_ref(), list2.as_ref(), Operator::NotEq).unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected);

        // Test less than
        let result = compare(list1.as_ref(), list2.as_ref(), Operator::Lt).unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected);
    }

    /// Comparison of a list array against a constant list (currently ignored, see the `ignore`
    /// reason).
    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_constant_comparison() {
        use std::sync::Arc;

        use vortex_dtype::DType;
        use vortex_dtype::PType;

        // Create a list array
        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
        let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list = ListArray::try_new(
            values.into_array(),
            offsets.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        // Create a constant list scalar [3,4] that will be broadcasted
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![3i32.into(), 4i32.into()],
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(list_scalar, 3);

        // Compare list with constant - all should be compared to [3,4]
        let result = compare(list.as_ref(), constant.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([false, true, false]);
        assert_arrays_eq!(result, expected);
    }

    /// Struct arrays are compared row by row; the two inputs differ only in their last row.
    #[test]
    fn test_struct_array_comparison() {
        // Create two struct arrays with bool and int fields
        let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
        let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]);

        let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]);
        let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]);

        let struct1 = StructArray::from_fields(&[
            ("bool_col", bool_field1.into_array()),
            ("int_col", int_field1.into_array()),
        ])
        .unwrap();

        let struct2 = StructArray::from_fields(&[
            ("bool_col", bool_field2.into_array()),
            ("int_col", int_field2.into_array()),
        ])
        .unwrap();

        // Test equality
        let result = compare(struct1.as_ref(), struct2.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([true, true, false]);
        assert_arrays_eq!(result, expected);

        // Test greater than
        let result = compare(struct1.as_ref(), struct2.as_ref(), Operator::Gt).unwrap();
        let expected = BoolArray::from_iter([false, false, true]);
        assert_arrays_eq!(result, expected);
    }

    /// Two field-less struct arrays of five rows compare equal in every row.
    #[test]
    fn test_empty_struct_compare() {
        let empty1 = StructArray::try_new(
            FieldNames::from(Vec::<FieldName>::new()),
            Vec::new(),
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let empty2 = StructArray::try_new(
            FieldNames::from(Vec::<FieldName>::new()),
            Vec::new(),
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let result = compare(empty1.as_ref(), empty2.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([true, true, true, true, true]);
        assert_arrays_eq!(result, expected);
    }

    /// Comparing an all-valid list-view array built over an empty element array with itself
    /// must produce a valid (non-null) result in each of its three rows.
    #[test]
    fn test_empty_list() {
        let list = ListViewArray::new(
            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            Validity::AllValid,
        );

        // Compare two lists together
        let result = compare(list.as_ref(), list.as_ref(), Operator::Eq).unwrap();
        assert!(result.scalar_at(0).unwrap().is_valid());
        assert!(result.scalar_at(1).unwrap().is_valid());
        assert!(result.scalar_at(2).unwrap().is_valid());
    }
}