vortex_array/arrays/masked/compute/
compare.rs1use vortex_error::VortexResult;
5
6use crate::arrays::{BoolArray, MaskedArray, MaskedVTable};
7use crate::canonical::ToCanonical;
8use crate::compute::{CompareKernel, CompareKernelAdapter, Operator, compare};
9use crate::vtable::ValidityHelper;
10use crate::{Array, ArrayRef, IntoArray, register_kernel};
11
12impl CompareKernel for MaskedVTable {
13 fn compare(
14 &self,
15 lhs: &MaskedArray,
16 rhs: &dyn Array,
17 operator: Operator,
18 ) -> VortexResult<Option<ArrayRef>> {
19 let compare_result = compare(&lhs.child, rhs, operator)?;
21
22 let bool_array = compare_result.to_bool();
24 let combined_validity = bool_array.validity().clone().and(lhs.validity().clone());
25
26 Ok(Some(
28 BoolArray::from_bool_buffer(bool_array.boolean_buffer().clone(), combined_validity)
29 .into_array(),
30 ))
31 }
32}
33
34register_kernel!(CompareKernelAdapter(MaskedVTable).lift());
35
#[cfg(test)]
mod tests {
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::{ConstantArray, MaskedArray, PrimitiveArray};
    use crate::compute::{Operator, compare};
    use crate::validity::Validity;
    use crate::{IntoArray, ToCanonical};

    /// Equality against a non-null constant with an all-valid mask.
    /// Every output slot must be a valid boolean.
    #[test]
    fn test_compare_value() {
        let masked = MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid,
        )
        .unwrap();

        let res = compare(
            masked.as_ref(),
            ConstantArray::new(Scalar::from(2i32), 3).as_ref(),
            Operator::Eq,
        )
        .unwrap();
        let res = res.to_bool();
        assert_eq!(
            res.boolean_buffer().iter().collect::<Vec<_>>(),
            vec![false, true, false]
        );
        // The mask's `AllValid` validity is AND-ed into the result. The
        // output is nullable, but nothing may be null.
        assert_eq!(res.dtype().nullability(), Nullability::Nullable);
        assert_eq!(res.validity_mask().true_count(), 3);
    }

    /// Non-equality operator (`Gt`) with an all-valid mask.
    /// Values and validity must both be correct.
    #[test]
    fn test_compare_non_eq() {
        let masked = MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid,
        )
        .unwrap();

        let res = compare(
            masked.as_ref(),
            ConstantArray::new(Scalar::from(2i32), 3).as_ref(),
            Operator::Gt,
        )
        .unwrap();
        let res = res.to_bool();
        assert_eq!(
            res.boolean_buffer().iter().collect::<Vec<_>>(),
            vec![false, false, true]
        );
        assert_eq!(res.dtype().nullability(), Nullability::Nullable);
        assert_eq!(res.validity_mask().true_count(), 3);
    }

    /// Masked-out slots of `lhs` must become null in the result.
    #[test]
    fn test_compare_nullable() {
        let masked = MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::from_iter([false, true, false]),
        )
        .unwrap();

        let res = compare(
            masked.as_ref(),
            ConstantArray::new(Scalar::primitive(2i32, Nullability::Nullable), 3).as_ref(),
            Operator::Eq,
        )
        .unwrap();
        let res = res.to_bool();
        // Slots 0 and 2 are null. Their buffer values are the raw child
        // comparisons (1 == 2, 3 == 2), which the kernel passes through unchanged.
        assert_eq!(
            res.boolean_buffer().iter().collect::<Vec<_>>(),
            vec![false, true, false]
        );
        assert_eq!(res.dtype().nullability(), Nullability::Nullable);
        assert_eq!(res.validity_mask(), Mask::from_iter([false, true, false]));
    }

    /// Nulls in `rhs` and masked slots in `lhs` must both surface as nulls.
    #[test]
    fn test_compare_with_null_rhs() {
        let masked = MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::from_iter([true, true, false]),
        )
        .unwrap();

        let rhs = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);

        let res = compare(masked.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let res = res.to_bool();
        // Slot 1 is null because `rhs` is null there. Its buffer value depends
        // on what `from_option_iter` stores for `None` (presumably a zero
        // default — TODO confirm). Slot 2 is null because of the mask.
        assert_eq!(
            res.boolean_buffer().iter().collect::<Vec<_>>(),
            vec![true, false, true]
        );
        assert_eq!(res.dtype().nullability(), Nullability::Nullable);
        assert_eq!(res.validity_mask(), Mask::from_iter([true, false, false]));
    }
}