vortex_array/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use core::fmt;
5use std::any::Any;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::sync::LazyLock;
9
10use arcref::ArcRef;
11use arrow_array::BooleanArray;
12use arrow_buffer::NullBuffer;
13use arrow_ord::cmp;
14use arrow_ord::ord::make_comparator;
15use arrow_schema::SortOptions;
16use vortex_buffer::BitBuffer;
17use vortex_dtype::DType;
18use vortex_dtype::IntegerPType;
19use vortex_dtype::Nullability;
20use vortex_error::VortexError;
21use vortex_error::VortexExpect;
22use vortex_error::VortexResult;
23use vortex_error::vortex_bail;
24use vortex_error::vortex_err;
25use vortex_scalar::Scalar;
26
27use crate::Array;
28use crate::ArrayRef;
29use crate::Canonical;
30use crate::IntoArray;
31use crate::arrays::ConstantArray;
32use crate::arrow::Datum;
33use crate::arrow::IntoArrowArray;
34use crate::arrow::from_arrow_array_with_len;
35use crate::compute::ComputeFn;
36use crate::compute::ComputeFnVTable;
37use crate::compute::InvocationArgs;
38use crate::compute::Kernel;
39use crate::compute::Options;
40use crate::compute::Output;
41use crate::vtable::VTable;
42
43static COMPARE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
44    let compute = ComputeFn::new("compare".into(), ArcRef::new_ref(&Compare));
45    for kernel in inventory::iter::<CompareKernelRef> {
46        compute.register_kernel(kernel.0.clone());
47    }
48    compute
49});
50
51pub(crate) fn warm_up_vtable() -> usize {
52    COMPARE_FN.kernels().len()
53}
54
55/// Compares two arrays and returns a new boolean array with the result of the comparison.
56/// Or, returns None if comparison is not supported for these arrays.
57pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
58    COMPARE_FN
59        .invoke(&InvocationArgs {
60            inputs: &[left.into(), right.into()],
61            options: &operator,
62        })?
63        .unwrap_array()
64}
65
/// A binary comparison operator, applied element-wise by [`compare`] as `lhs <op> rhs`.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Hash)]
pub enum Operator {
    /// Equality (`=`)
    Eq,
    /// Inequality (`!=`)
    NotEq,
    /// Greater than (`>`)
    Gt,
    /// Greater than or equal (`>=`)
    Gte,
    /// Less than (`<`)
    Lt,
    /// Less than or equal (`<=`)
    Lte,
}
81
82impl Display for Operator {
83    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
84        let display = match &self {
85            Operator::Eq => "=",
86            Operator::NotEq => "!=",
87            Operator::Gt => ">",
88            Operator::Gte => ">=",
89            Operator::Lt => "<",
90            Operator::Lte => "<=",
91        };
92        Display::fmt(display, f)
93    }
94}
95
96impl Operator {
97    pub fn inverse(self) -> Self {
98        match self {
99            Operator::Eq => Operator::NotEq,
100            Operator::NotEq => Operator::Eq,
101            Operator::Gt => Operator::Lte,
102            Operator::Gte => Operator::Lt,
103            Operator::Lt => Operator::Gte,
104            Operator::Lte => Operator::Gt,
105        }
106    }
107
108    /// Change the sides of the operator, where changing lhs and rhs won't change the result of the operation
109    pub fn swap(self) -> Self {
110        match self {
111            Operator::Eq => Operator::Eq,
112            Operator::NotEq => Operator::NotEq,
113            Operator::Gt => Operator::Lt,
114            Operator::Gte => Operator::Lte,
115            Operator::Lt => Operator::Gt,
116            Operator::Lte => Operator::Gte,
117        }
118    }
119}
120
/// A type-erased compare [`Kernel`] submitted through `inventory`.
///
/// Every collected instance is registered with `COMPARE_FN` when that function is first
/// initialised. Build one with [`CompareKernelAdapter::lift`].
pub struct CompareKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(CompareKernelRef);
123
/// Encoding-specific comparison, implemented on an encoding's [`VTable`] to specialise
/// [`compare`] when the left-hand operand is an array of that encoding.
///
/// Wrap the implementation in a [`CompareKernelAdapter`] to expose it as a compute [`Kernel`].
pub trait CompareKernel: VTable {
    /// Evaluates `lhs <operator> rhs` element-wise.
    ///
    /// Return `Ok(None)` to decline (for instance for an unsupported `rhs` encoding or operator);
    /// the dispatcher then tries other kernels and finally falls back to Arrow. The dispatcher may
    /// also call this with the operands swapped and `operator` mirrored via `Operator::swap`.
    /// NOTE(review): presumably `rhs` has already been validated to match `lhs` in length and in
    /// dtype (ignoring nullability) — confirm against `ComputeFn::invoke`.
    fn compare(
        &self,
        lhs: &Self::Array,
        rhs: &dyn Array,
        operator: Operator,
    ) -> VortexResult<Option<ArrayRef>>;
}
132
133#[derive(Debug)]
134pub struct CompareKernelAdapter<V: VTable>(pub V);
135
136impl<V: VTable + CompareKernel> CompareKernelAdapter<V> {
137    pub const fn lift(&'static self) -> CompareKernelRef {
138        CompareKernelRef(ArcRef::new_ref(self))
139    }
140}
141
142impl<V: VTable + CompareKernel> Kernel for CompareKernelAdapter<V> {
143    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
144        let inputs = CompareArgs::try_from(args)?;
145        let Some(array) = inputs.lhs.as_opt::<V>() else {
146            return Ok(None);
147        };
148        Ok(V::compare(&self.0, array, inputs.rhs, inputs.operator)?.map(|array| array.into()))
149    }
150}
151
/// The [`ComputeFnVTable`] backing [`compare`] (registered as `COMPARE_FN`).
struct Compare;

impl ComputeFnVTable for Compare {
    /// Evaluates the comparison, trying cheaper paths first:
    ///
    /// 1. an empty lhs yields an empty array of the return dtype;
    /// 2. a constant-null operand on either side yields an all-null constant result;
    /// 3. a constant lhs with a non-constant rhs is re-dispatched as `rhs op.swap() lhs`;
    /// 4. registered kernels, then the lhs array's own `invoke`, for `lhs op rhs`;
    /// 5. the same again for the swapped form `rhs op.swap() lhs`;
    /// 6. otherwise, the Arrow fallback (`arrow_compare`).
    fn invoke(
        &self,
        args: &InvocationArgs,
        kernels: &[ArcRef<dyn Kernel>],
    ) -> VortexResult<Output> {
        let CompareArgs { lhs, rhs, operator } = CompareArgs::try_from(args)?;

        let return_dtype = self.return_dtype(args)?;

        if lhs.is_empty() {
            return Ok(Canonical::empty(&return_dtype).into_array().into());
        }

        // A comparison against null is null, so a constant-null operand on either side makes
        // every output element null.
        let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
        let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
        if left_constant_null || right_constant_null {
            return Ok(ConstantArray::new(Scalar::null(return_dtype), lhs.len())
                .into_array()
                .into());
        }

        let right_is_constant = rhs.is_constant();

        // Always try to put constants on the right-hand side so encodings can optimise themselves.
        if lhs.is_constant() && !right_is_constant {
            return Ok(compare(rhs, lhs, operator.swap())?.into());
        }

        // First try `lhs op rhs`; if nothing handles it, swap the operands (mirroring the
        // operator) and try again.
        for kernel in kernels {
            if let Some(output) = kernel.invoke(args)? {
                return Ok(output);
            }
        }
        // NOTE(review): presumably gives the lhs array's own encoding a chance to handle the call.
        if let Some(output) = lhs.invoke(&COMPARE_FN, args)? {
            return Ok(output);
        }

        // Swap the arguments and mirror the operator with `Operator::swap` (not `inverse`):
        // `rhs op.swap() lhs` is equivalent to `lhs op rhs`.
        let inverted_args = InvocationArgs {
            inputs: &[rhs.into(), lhs.into()],
            options: &operator.swap(),
        };
        for kernel in kernels {
            if let Some(output) = kernel.invoke(&inverted_args)? {
                return Ok(output);
            }
        }
        if let Some(output) = rhs.invoke(&COMPARE_FN, &inverted_args)? {
            return Ok(output);
        }

        // Only log missing compare implementation if there's possibly better one than arrow,
        // i.e. lhs isn't arrow or rhs isn't arrow or constant
        if !(lhs.is_arrow() && (rhs.is_arrow() || right_is_constant)) {
            log::debug!(
                "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
                lhs.encoding_id(),
                rhs.encoding_id(),
                operator,
            );
        }

        // Fallback to arrow on canonical types
        Ok(arrow_compare(lhs, rhs, operator)?.into())
    }

    /// Requires both operands to share a dtype, ignoring nullability. Mismatched float or integer
    /// widths get a dedicated message suggesting a cast. The result is `Bool`, nullable if either
    /// operand is nullable.
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;

        if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) {
            if lhs.dtype().is_float() && rhs.dtype().is_float() {
                vortex_bail!(
                    "Cannot compare different floating-point types ({}, {}). Consider using cast.",
                    lhs.dtype(),
                    rhs.dtype(),
                );
            }
            if lhs.dtype().is_int() && rhs.dtype().is_int() {
                vortex_bail!(
                    "Cannot compare different fixed-width types ({}, {}). Consider using cast.",
                    lhs.dtype(),
                    rhs.dtype()
                );
            }
            vortex_bail!(
                "Cannot compare different DTypes {} and {}",
                lhs.dtype(),
                rhs.dtype()
            );
        }

        Ok(DType::Bool(
            lhs.dtype().nullability() | rhs.dtype().nullability(),
        ))
    }

    /// Both operands must have the same length, which is also the output length.
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
        if lhs.len() != rhs.len() {
            vortex_bail!(
                "Compare operations only support arrays of the same length, got {} and {}",
                lhs.len(),
                rhs.len()
            );
        }
        Ok(lhs.len())
    }

    /// Comparison is element-wise: output element `i` depends only on input elements `i`.
    fn is_elementwise(&self) -> bool {
        true
    }
}
268
/// The operands and operator of a `compare` invocation, unpacked from [`InvocationArgs`].
struct CompareArgs<'a> {
    /// Left-hand operand.
    lhs: &'a dyn Array,
    /// Right-hand operand.
    rhs: &'a dyn Array,
    /// Comparison applied as `lhs <operator> rhs`.
    operator: Operator,
}
274
275impl Options for Operator {
276    fn as_any(&self) -> &dyn Any {
277        self
278    }
279}
280
281impl<'a> TryFrom<&InvocationArgs<'a>> for CompareArgs<'a> {
282    type Error = VortexError;
283
284    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
285        if value.inputs.len() != 2 {
286            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
287        }
288        let lhs = value.inputs[0]
289            .array()
290            .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
291        let rhs = value.inputs[1]
292            .array()
293            .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
294        let operator = *value
295            .options
296            .as_any()
297            .downcast_ref::<Operator>()
298            .vortex_expect("Expected options to be an operator");
299
300        Ok(CompareArgs { lhs, rhs, operator })
301    }
302}
303
304/// Helper function to compare empty values with arrays that have external value length information
305/// like `VarBin`.
306pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BitBuffer
307where
308    P: IntegerPType,
309    I: Iterator<Item = P>,
310{
311    // All comparison can be expressed in terms of equality. "" is the absolute min of possible value.
312    let cmp_fn = match op {
313        Operator::Eq | Operator::Lte => |v| v == P::zero(),
314        Operator::NotEq | Operator::Gt => |v| v != P::zero(),
315        Operator::Gte => |_| true,
316        Operator::Lt => |_| false,
317    };
318
319    lengths.map(cmp_fn).collect()
320}
321
322/// Implementation of `CompareFn` using the Arrow crate.
323fn arrow_compare(
324    left: &dyn Array,
325    right: &dyn Array,
326    operator: Operator,
327) -> VortexResult<ArrayRef> {
328    assert_eq!(left.len(), right.len());
329
330    let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
331
332    let array = if left.dtype().is_nested() || right.dtype().is_nested() {
333        let rhs = right.to_array().into_arrow_preferred()?;
334        let lhs = left.to_array().into_arrow(rhs.data_type())?;
335
336        assert!(
337            lhs.data_type().equals_datatype(rhs.data_type()),
338            "lhs data_type: {}, rhs data_type: {}",
339            lhs.data_type(),
340            rhs.data_type()
341        );
342
343        let cmp = make_comparator(lhs.as_ref(), rhs.as_ref(), SortOptions::default())?;
344        let len = left.len();
345        let values = (0..len)
346            .map(|i| {
347                let cmp = cmp(i, i);
348                match operator {
349                    Operator::Eq => cmp.is_eq(),
350                    Operator::NotEq => cmp.is_ne(),
351                    Operator::Gt => cmp.is_gt(),
352                    Operator::Gte => cmp.is_gt() || cmp.is_eq(),
353                    Operator::Lt => cmp.is_lt(),
354                    Operator::Lte => cmp.is_lt() || cmp.is_eq(),
355                }
356            })
357            .collect();
358        let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
359        BooleanArray::new(values, nulls)
360    } else {
361        let lhs = Datum::try_new(left)?;
362        let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type())?;
363
364        match operator {
365            Operator::Eq => cmp::eq(&lhs, &rhs)?,
366            Operator::NotEq => cmp::neq(&lhs, &rhs)?,
367            Operator::Gt => cmp::gt(&lhs, &rhs)?,
368            Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
369            Operator::Lt => cmp::lt(&lhs, &rhs)?,
370            Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
371        }
372    };
373    Ok(from_arrow_array_with_len(&array, left.len(), nullable))
374}
375
376pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
377    if lhs.is_null() | rhs.is_null() {
378        Scalar::null(DType::Bool(Nullability::Nullable))
379    } else {
380        let b = match operator {
381            Operator::Eq => lhs == rhs,
382            Operator::NotEq => lhs != rhs,
383            Operator::Gt => lhs > rhs,
384            Operator::Gte => lhs >= rhs,
385            Operator::Lt => lhs < rhs,
386            Operator::Lte => lhs <= rhs,
387        };
388
389        Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
390    }
391}
392
// Tests for `compare`, its Arrow fallback, and the comparison helpers in this module.
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::FieldName;
    use vortex_dtype::FieldNames;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::ListArray;
    use crate::arrays::ListViewArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::expr::get_item;
    use crate::expr::lt;
    use crate::expr::root;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Exercises every operator on nullable boolean arrays. Index 0 is null in both inputs, so it
    /// never appears among the matching indices.
    #[test]
    fn test_bool_basic_comparisons() {
        let arr = BoolArray::from_bit_buffer(
            BitBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = compare(arr.as_ref(), arr.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool();

        assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]);

        let matches = compare(arr.as_ref(), arr.as_ref(), Operator::NotEq)
            .unwrap()
            .to_bool();
        let empty: [u64; 0] = [];
        assert_eq!(to_int_indices(matches).unwrap(), empty);

        let other = BoolArray::from_bit_buffer(
            BitBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = compare(arr.as_ref(), other.as_ref(), Operator::Lte)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = compare(arr.as_ref(), other.as_ref(), Operator::Lt)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);

        // Swapping the operands together with the operator must give the same matches as above.
        let matches = compare(other.as_ref(), arr.as_ref(), Operator::Gte)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = compare(other.as_ref(), arr.as_ref(), Operator::Gt)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
    }

    /// Two constant operands yield a constant-valued result, both via `compare` and via the
    /// Arrow fallback directly.
    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        let compare = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap();
        let res = compare.as_constant().unwrap();
        assert_eq!(res.as_bool().value(), Some(false));
        assert_eq!(compare.len(), 10);

        let compare = arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
        let res = compare.as_constant().unwrap();
        assert_eq!(res.as_bool().value(), Some(false));
        assert_eq!(compare.len(), 10);
    }

    /// `compare_lengths_to_empty` evaluates `value <op> ""` from lengths alone; only the last
    /// value (length 0) is empty.
    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths: Vec<i32> = vec![1, 5, 7, 0];

        let output = compare_lengths_to_empty(lengths.iter().copied(), op);
        assert_eq!(Vec::from_iter(output.iter()), expected);
    }

    /// Equal values stored in different encodings (VarBin vs VarBinView, utf8 and binary) compare
    /// equal in both operand orders.
    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let res = compare(&left, &right, Operator::Eq).unwrap();
        assert_eq!(res.to_bool().bit_buffer().true_count(), left.len());
    }

    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_comparison() {
        // Create two simple list arrays with integers
        let values1 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
        let offsets1 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list1 = ListArray::try_new(
            values1.into_array(),
            offsets1.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let values2 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 7, 8]);
        let offsets2 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list2 = ListArray::try_new(
            values2.into_array(),
            offsets2.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        // Test equality - first two lists should be equal, third should be different
        let result = compare(list1.as_ref(), list2.as_ref(), Operator::Eq).unwrap();
        let bool_result = result.to_bool();
        assert!(bool_result.bit_buffer().value(0)); // [1,2] == [1,2]
        assert!(bool_result.bit_buffer().value(1)); // [3,4] == [3,4]
        assert!(!bool_result.bit_buffer().value(2)); // [5,6] != [7,8]

        // Test inequality
        let result = compare(list1.as_ref(), list2.as_ref(), Operator::NotEq).unwrap();
        let bool_result = result.to_bool();
        assert!(!bool_result.bit_buffer().value(0));
        assert!(!bool_result.bit_buffer().value(1));
        assert!(bool_result.bit_buffer().value(2));

        // Test less than
        let result = compare(list1.as_ref(), list2.as_ref(), Operator::Lt).unwrap();
        let bool_result = result.to_bool();
        assert!(!bool_result.bit_buffer().value(0)); // [1,2] < [1,2] = false
        assert!(!bool_result.bit_buffer().value(1)); // [3,4] < [3,4] = false
        assert!(bool_result.bit_buffer().value(2)); // [5,6] < [7,8] = true
    }

    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_constant_comparison() {
        use std::sync::Arc;

        use vortex_dtype::DType;
        use vortex_dtype::PType;

        // Create a list array
        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
        let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        let list = ListArray::try_new(
            values.into_array(),
            offsets.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        // Create a constant list scalar [3,4] that will be broadcasted
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![3i32.into(), 4i32.into()],
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(list_scalar, 3);

        // Compare list with constant - all should be compared to [3,4]
        let result = compare(list.as_ref(), constant.as_ref(), Operator::Eq).unwrap();
        let bool_result = result.to_bool();
        assert!(!bool_result.bit_buffer().value(0)); // [1,2] != [3,4]
        assert!(bool_result.bit_buffer().value(1)); // [3,4] == [3,4]
        assert!(!bool_result.bit_buffer().value(2)); // [5,6] != [3,4]
    }

    /// Structs compare field by field in declaration order (the nested Arrow path).
    #[test]
    fn test_struct_array_comparison() {
        // Create two struct arrays with bool and int fields
        let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
        let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]);

        let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]);
        let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]);

        let struct1 = StructArray::from_fields(&[
            ("bool_col", bool_field1.into_array()),
            ("int_col", int_field1.into_array()),
        ])
        .unwrap();

        let struct2 = StructArray::from_fields(&[
            ("bool_col", bool_field2.into_array()),
            ("int_col", int_field2.into_array()),
        ])
        .unwrap();

        // Test equality
        let result = compare(struct1.as_ref(), struct2.as_ref(), Operator::Eq).unwrap();
        let bool_result = result.to_bool();
        assert!(bool_result.bit_buffer().value(0)); // {true, 1} == {true, 1}
        assert!(bool_result.bit_buffer().value(1)); // {false, 2} == {false, 2}
        assert!(!bool_result.bit_buffer().value(2)); // {true, 3} != {false, 4}

        // Test greater than
        let result = compare(struct1.as_ref(), struct2.as_ref(), Operator::Gt).unwrap();
        let bool_result = result.to_bool();
        assert!(!bool_result.bit_buffer().value(0)); // {true, 1} > {true, 1} = false
        assert!(!bool_result.bit_buffer().value(1)); // {false, 2} > {false, 2} = false
        assert!(bool_result.bit_buffer().value(2)); // {true, 3} > {false, 4} = true (bool field takes precedence)
    }

    /// Structs with no fields are all equal to each other.
    #[test]
    fn test_empty_struct_compare() {
        let empty1 = StructArray::try_new(
            FieldNames::from(Vec::<FieldName>::new()),
            Vec::new(),
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let empty2 = StructArray::try_new(
            FieldNames::from(Vec::<FieldName>::new()),
            Vec::new(),
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let result = compare(empty1.as_ref(), empty2.as_ref(), Operator::Eq).unwrap();
        let result = result.to_bool();

        for idx in 0..5 {
            assert!(result.bit_buffer().value(idx));
        }
    }

    /// Comparing a list-view array of empty lists with itself yields valid (non-null) results.
    #[test]
    fn test_empty_list() {
        let list = ListViewArray::new(
            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            Validity::AllValid,
        );

        // Compare two lists together
        let result = compare(list.as_ref(), list.as_ref(), Operator::Eq).unwrap();
        assert!(result.scalar_at(0).is_valid());
        assert!(result.scalar_at(1).is_valid());
        assert!(result.scalar_at(2).is_valid());
    }

    /// Mismatched float widths are rejected with a targeted error, both via `compare` and via an
    /// expression.
    #[test]
    fn test_different_floats_error_messages() {
        let result = compare(
            &buffer![0.0f32].into_array(),
            &buffer![0.0f64].into_array(),
            Operator::Lt,
        );
        assert!(result.as_ref().is_err_and(|err| {
            err.to_string()
                .contains("Cannot compare different floating-point types")
        }));

        let expr = lt(get_item("l", root()), get_item("r", root()));
        let result = expr.evaluate(
            &StructArray::from_fields(&[
                ("l", buffer![0.0f32].into_array()),
                ("r", buffer![0.0f64].into_array()),
            ])
            .unwrap()
            .into_array(),
        );
        assert!(result.as_ref().is_err_and(|err| {
            err.to_string()
                .contains("Cannot compare different floating-point types")
        }));
    }

    /// Mismatched integer widths are rejected with a targeted error, both via `compare` and via
    /// an expression.
    #[test]
    fn test_different_ints_error_messages() {
        let result = compare(
            &buffer![0u8].into_array(),
            &buffer![0u16].into_array(),
            Operator::Lt,
        );
        assert!(result.as_ref().is_err_and(|err| {
            err.to_string()
                .contains("Cannot compare different fixed-width types")
        }));

        let expr = lt(get_item("l", root()), get_item("r", root()));
        let result = expr.evaluate(
            &StructArray::from_fields(&[
                ("l", buffer![0u8].into_array()),
                ("r", buffer![0u16].into_array()),
            ])
            .unwrap()
            .into_array(),
        );
        assert!(result.as_ref().is_err_and(|err| {
            err.to_string()
                .contains("Cannot compare different fixed-width types")
        }));
    }
}