vortex_array/arrays/masked/compute/
compare.rs1use vortex_error::VortexResult;
5
6use crate::Array;
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::arrays::BoolArray;
10use crate::arrays::MaskedArray;
11use crate::arrays::MaskedVTable;
12use crate::canonical::ToCanonical;
13use crate::compute::CompareKernel;
14use crate::compute::CompareKernelAdapter;
15use crate::compute::Operator;
16use crate::compute::compare;
17use crate::register_kernel;
18use crate::vtable::ValidityHelper;
19
20impl CompareKernel for MaskedVTable {
21 fn compare(
22 &self,
23 lhs: &MaskedArray,
24 rhs: &dyn Array,
25 operator: Operator,
26 ) -> VortexResult<Option<ArrayRef>> {
27 let compare_result = compare(&lhs.child, rhs, operator)?;
29
30 let bool_array = compare_result.to_bool();
32 let combined_validity = bool_array.validity().clone().and(lhs.validity().clone());
33
34 Ok(Some(
36 BoolArray::from_bit_buffer(bool_array.bit_buffer().clone(), combined_validity)
37 .into_array(),
38 ))
39 }
40}
41
42register_kernel!(CompareKernelAdapter(MaskedVTable).lift());
43
#[cfg(test)]
mod tests {
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::MaskedArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::Operator;
    use crate::compute::compare;
    use crate::validity::Validity;

    /// Builds a `MaskedArray` over the values `[1, 2, 3]` using `validity` as its validity.
    fn masked_one_to_three(validity: Validity) -> MaskedArray {
        let child = PrimitiveArray::from_iter([1i32, 2, 3]).into_array();
        MaskedArray::try_new(child, validity).unwrap()
    }

    /// Compares `lhs` against a length-3 constant array of `rhs` and canonicalizes to booleans.
    fn compare_to_constant(
        lhs: &MaskedArray,
        rhs: Scalar,
        operator: Operator,
    ) -> crate::arrays::BoolArray {
        let constant = ConstantArray::new(rhs, 3);
        compare(lhs.as_ref(), constant.as_ref(), operator)
            .unwrap()
            .to_bool()
    }

    /// Collects the raw value bits of a boolean array, ignoring its validity.
    fn raw_bits(array: &crate::arrays::BoolArray) -> Vec<bool> {
        array.bit_buffer().iter().collect()
    }

    #[test]
    fn test_compare_value() {
        let lhs = masked_one_to_three(Validity::AllValid);

        let result = compare_to_constant(&lhs, Scalar::from(2i32), Operator::Eq);

        assert_eq!(raw_bits(&result), vec![false, true, false]);
    }

    #[test]
    fn test_compare_non_eq() {
        let lhs = masked_one_to_three(Validity::AllValid);

        let result = compare_to_constant(&lhs, Scalar::from(2i32), Operator::Gt);

        assert_eq!(raw_bits(&result), vec![false, false, true]);
    }

    #[test]
    fn test_compare_nullable() {
        let lhs = masked_one_to_three(Validity::from_iter([false, true, false]));
        let two = Scalar::primitive(2i32, Nullability::Nullable);

        let result = compare_to_constant(&lhs, two, Operator::Eq);

        assert_eq!(raw_bits(&result), vec![false, true, false]);
        assert_eq!(result.dtype().nullability(), Nullability::Nullable);
        assert_eq!(result.validity_mask(), Mask::from_iter([false, true, false]));
    }

    #[test]
    fn test_compare_with_null_rhs() {
        let lhs = masked_one_to_three(Validity::from_iter([true, true, false]));
        let rhs = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);

        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool();

        // NOTE(review): the middle expectation checks the value bit under a null rhs slot,
        // which depends on the placeholder value stored for nulls — consider asserting only
        // valid positions.
        assert_eq!(raw_bits(&result), vec![true, false, true]);
        assert_eq!(result.dtype().nullability(), Nullability::Nullable);
        // Validity is the AND of the lhs validity [t, t, f] and the rhs validity [t, f, t].
        assert_eq!(result.validity_mask(), Mask::from_iter([true, false, false]));
    }
}