vortex_array/arrays/masked/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::arrays::{BoolArray, MaskedArray, MaskedVTable};
7use crate::canonical::ToCanonical;
8use crate::compute::{CompareKernel, CompareKernelAdapter, Operator, compare};
9use crate::vtable::ValidityHelper;
10use crate::{Array, ArrayRef, IntoArray, register_kernel};
11
12impl CompareKernel for MaskedVTable {
13    fn compare(
14        &self,
15        lhs: &MaskedArray,
16        rhs: &dyn Array,
17        operator: Operator,
18    ) -> VortexResult<Option<ArrayRef>> {
19        // Compare the child arrays
20        let compare_result = compare(&lhs.child, rhs, operator)?;
21
22        // Get the boolean buffer from the comparison result
23        let bool_array = compare_result.to_bool();
24        let combined_validity = bool_array.validity().clone().and(lhs.validity().clone());
25
26        // Return a plain BoolArray with the combined validity
27        Ok(Some(
28            BoolArray::from_bool_buffer(bool_array.boolean_buffer().clone(), combined_validity)
29                .into_array(),
30        ))
31    }
32}
33
34register_kernel!(CompareKernelAdapter(MaskedVTable).lift());
35
#[cfg(test)]
mod tests {
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::{ConstantArray, MaskedArray, PrimitiveArray};
    use crate::compute::{Operator, compare};
    use crate::validity::Validity;
    use crate::{IntoArray, ToCanonical};

    /// Builds a `MaskedArray` over the non-nullable values `[1, 2, 3]` with the given validity.
    fn masked_one_two_three(validity: Validity) -> MaskedArray {
        MaskedArray::try_new(PrimitiveArray::from_iter([1i32, 2, 3]).into_array(), validity)
            .unwrap()
    }

    /// `Eq` against a non-nullable constant compares every slot of the child.
    #[test]
    fn test_compare_value() {
        let masked = masked_one_two_three(Validity::AllValid);

        let res = compare(
            masked.as_ref(),
            ConstantArray::new(Scalar::from(2i32), 3).as_ref(),
            Operator::Eq,
        )
        .unwrap();
        let res = res.to_bool();
        assert_eq!(
            res.boolean_buffer().iter().collect::<Vec<_>>(),
            vec![false, true, false]
        );
    }

    /// Non-equality operators (`Gt`) are forwarded to the child comparison as well.
    #[test]
    fn test_compare_non_eq() {
        let masked = masked_one_two_three(Validity::AllValid);

        let res = compare(
            masked.as_ref(),
            ConstantArray::new(Scalar::from(2i32), 3).as_ref(),
            Operator::Gt,
        )
        .unwrap();
        let res = res.to_bool();
        assert_eq!(
            res.boolean_buffer().iter().collect::<Vec<_>>(),
            vec![false, false, true]
        );
    }

    /// Masked-out slots in the input become null in the (nullable) result.
    #[test]
    fn test_compare_nullable() {
        // MaskedArray with nulls at positions 0 and 2.
        let masked = masked_one_two_three(Validity::from_iter([false, true, false]));

        let res = compare(
            masked.as_ref(),
            ConstantArray::new(Scalar::primitive(2i32, Nullability::Nullable), 3).as_ref(),
            Operator::Eq,
        )
        .unwrap();
        let res = res.to_bool();
        // Bits under null slots are whatever the raw child comparison produced.
        assert_eq!(
            res.boolean_buffer().iter().collect::<Vec<_>>(),
            vec![false, true, false]
        );
        assert_eq!(res.dtype().nullability(), Nullability::Nullable);
        assert_eq!(res.validity_mask(), Mask::from_iter([false, true, false]));
    }

    /// Nulls coming from the right-hand side are combined with the input mask.
    #[test]
    fn test_compare_with_null_rhs() {
        // MaskedArray with a null at position 2.
        let masked = masked_one_two_three(Validity::from_iter([true, true, false]));

        // RHS has a null value at position 1.
        let rhs = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);

        let res = compare(masked.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let res = res.to_bool();
        // Bits under null slots are whatever the raw child comparison produced.
        assert_eq!(
            res.boolean_buffer().iter().collect::<Vec<_>>(),
            vec![true, false, true]
        );
        assert_eq!(res.dtype().nullability(), Nullability::Nullable);
        // A slot is valid only if it is valid on both sides (logical AND / intersection of
        // valid slots): lhs=[T,T,F], rhs=[T,F,T] => result=[T,F,F].
        assert_eq!(res.validity_mask(), Mask::from_iter([true, false, false]));
    }

    /// A fully masked input yields an all-null result, whatever the child values compare to.
    #[test]
    fn test_compare_all_masked() {
        let masked = masked_one_two_three(Validity::from_iter([false, false, false]));

        let res = compare(
            masked.as_ref(),
            ConstantArray::new(Scalar::from(2i32), 3).as_ref(),
            Operator::Eq,
        )
        .unwrap();
        let res = res.to_bool();
        assert_eq!(res.dtype().nullability(), Nullability::Nullable);
        assert_eq!(res.validity_mask().true_count(), 0);
    }
}