vortex_array/arrays/masked/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::Array;
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::arrays::BoolArray;
10use crate::arrays::MaskedArray;
11use crate::arrays::MaskedVTable;
12use crate::canonical::ToCanonical;
13use crate::compute::CompareKernel;
14use crate::compute::CompareKernelAdapter;
15use crate::compute::Operator;
16use crate::compute::compare;
17use crate::register_kernel;
18use crate::vtable::ValidityHelper;
19
20impl CompareKernel for MaskedVTable {
21    fn compare(
22        &self,
23        lhs: &MaskedArray,
24        rhs: &dyn Array,
25        operator: Operator,
26    ) -> VortexResult<Option<ArrayRef>> {
27        // Compare the child arrays
28        let compare_result = compare(&lhs.child, rhs, operator)?;
29
30        // Get the boolean buffer from the comparison result
31        let bool_array = compare_result.to_bool();
32        let combined_validity = bool_array.validity().clone().and(lhs.validity().clone());
33
34        // Return a plain BoolArray with the combined validity
35        Ok(Some(
36            BoolArray::from_bit_buffer(bool_array.bit_buffer().clone(), combined_validity)
37                .into_array(),
38        ))
39    }
40}
41
// Register the masked-array compare kernel defined above.
// NOTE(review): `.lift()` presumably adapts `CompareKernelAdapter` into the form the
// kernel registry expects — confirm against `CompareKernelAdapter`'s definition.
register_kernel!(CompareKernelAdapter(MaskedVTable).lift());
43
#[cfg(test)]
mod tests {
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::MaskedArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::Operator;
    use crate::compute::compare;
    use crate::validity::Validity;

    /// `Eq` against a constant on a fully valid masked array yields the plain comparison.
    #[test]
    fn test_compare_value() {
        let masked = MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid,
        )
        .unwrap();

        let res = compare(
            masked.as_ref(),
            ConstantArray::new(Scalar::from(2i32), 3).as_ref(),
            Operator::Eq,
        )
        .unwrap();
        let res = res.to_bool();
        assert_eq!(
            res.bit_buffer().iter().collect::<Vec<_>>(),
            vec![false, true, false]
        );
    }

    /// A non-equality operator (`Gt`) is forwarded to the child comparison unchanged.
    #[test]
    fn test_compare_non_eq() {
        let masked = MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid,
        )
        .unwrap();

        let res = compare(
            masked.as_ref(),
            ConstantArray::new(Scalar::from(2i32), 3).as_ref(),
            Operator::Gt,
        )
        .unwrap();
        let res = res.to_bool();
        assert_eq!(
            res.bit_buffer().iter().collect::<Vec<_>>(),
            vec![false, false, true]
        );
    }

    /// Null slots in the masked array become null slots in the comparison result.
    #[test]
    fn test_compare_nullable() {
        // MaskedArray with nulls
        let masked = MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::from_iter([false, true, false]),
        )
        .unwrap();

        let res = compare(
            masked.as_ref(),
            ConstantArray::new(Scalar::primitive(2i32, Nullability::Nullable), 3).as_ref(),
            Operator::Eq,
        )
        .unwrap();
        let res = res.to_bool();
        assert_eq!(
            res.bit_buffer().iter().collect::<Vec<_>>(),
            vec![false, true, false]
        );
        assert_eq!(res.dtype().nullability(), Nullability::Nullable);
        assert_eq!(res.validity_mask(), Mask::from_iter([false, true, false]));
    }

    /// Nulls from both the masked array and the right-hand side propagate to the result.
    #[test]
    fn test_compare_with_null_rhs() {
        // MaskedArray with some nulls
        let masked = MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::from_iter([true, true, false]),
        )
        .unwrap();

        // RHS has a null value
        let rhs = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);

        let res = compare(masked.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let res = res.to_bool();
        // The raw bits come from the child comparison, even under slots that end up null
        // (e.g. index 2 compares 3 == 3 -> true, but is masked out below).
        assert_eq!(
            res.bit_buffer().iter().collect::<Vec<_>>(),
            vec![true, false, true]
        );
        assert_eq!(res.dtype().nullability(), Nullability::Nullable);
        // Validity is the intersection (logical AND) of both sides:
        // lhs=[T,T,F], rhs=[T,F,T] => result=[T,F,F]
        assert_eq!(res.validity_mask(), Mask::from_iter([true, false, false]));
    }
}