1use arrow_array::BooleanArray;
2use vortex_error::{VortexExpect, VortexResult, vortex_bail};
3use vortex_mask::Mask;
4use vortex_scalar::Scalar;
5
6use crate::arrays::ConstantArray;
7use crate::arrow::{FromArrowArray, IntoArrowArray};
8use crate::compute::try_cast;
9use crate::encoding::Encoding;
10use crate::{Array, ArrayRef};
11
12pub trait MaskFn<A> {
13 fn mask(&self, array: A, mask: Mask) -> VortexResult<ArrayRef>;
15}
16
17impl<E: Encoding> MaskFn<&dyn Array> for E
18where
19 E: for<'a> MaskFn<&'a E::Array>,
20{
21 fn mask(&self, array: &dyn Array, mask: Mask) -> VortexResult<ArrayRef> {
22 let array_ref = array
23 .as_any()
24 .downcast_ref::<E::Array>()
25 .vortex_expect("Failed to downcast array");
26 MaskFn::mask(self, array_ref, mask)
27 }
28}
29
30pub fn mask(array: &dyn Array, mask: Mask) -> VortexResult<ArrayRef> {
60 if mask.len() != array.len() {
61 vortex_bail!(
62 "mask.len() is {}, does not equal array.len() of {}",
63 mask.len(),
64 array.len()
65 );
66 }
67
68 let masked = if matches!(mask, Mask::AllFalse(_)) {
69 try_cast(array, &array.dtype().as_nullable())?
71 } else if matches!(mask, Mask::AllTrue(_)) {
72 ConstantArray::new(
74 Scalar::null(array.dtype().clone().as_nullable()),
75 array.len(),
76 )
77 .into_array()
78 } else {
79 mask_impl(array, mask)?
80 };
81
82 debug_assert_eq!(
83 masked.len(),
84 array.len(),
85 "Mask should not change length {}\n\n{:?}\n\n{:?}",
86 array.encoding(),
87 array,
88 masked
89 );
90 debug_assert_eq!(
91 masked.dtype(),
92 &array.dtype().as_nullable(),
93 "Mask dtype mismatch {} {} {} {}",
94 array.encoding(),
95 masked.dtype(),
96 array.dtype(),
97 array.dtype().as_nullable(),
98 );
99
100 Ok(masked)
101}
102
103fn mask_impl(array: &dyn Array, mask: Mask) -> VortexResult<ArrayRef> {
104 if let Some(mask_fn) = array.vtable().mask_fn() {
105 return mask_fn.mask(array, mask);
106 }
107
108 log::debug!("No mask implementation found for {}", array.encoding());
110
111 let array_ref = array.to_array().into_arrow_preferred()?;
112 let mask = BooleanArray::new(mask.to_boolean_buffer(), None);
113
114 let masked = arrow_select::nullif::nullif(array_ref.as_ref(), &mask)?;
115
116 Ok(ArrayRef::from_arrow(masked, true))
117}
118
#[cfg(feature = "test-harness")]
pub mod test_harness {
    //! Shared conformance checks for [`mask`] implementations.
    //!
    //! Each check masks a 5-element array and inspects every position of the result. Positions
    //! selected by the mask must be null. All other positions must equal the original value
    //! promoted to a nullable scalar.

    use vortex_mask::Mask;

    use crate::arrays::BoolArray;
    use crate::compute::{mask, scalar_at};
    use crate::{Array, ArrayRef};

    /// Runs every mask check against `array`.
    ///
    /// # Panics
    ///
    /// Panics if `array` does not contain exactly 5 elements or if any check fails.
    pub fn test_mask(array: &dyn Array) {
        assert_eq!(array.len(), 5);
        test_heterogeneous_mask(array);
        test_empty_mask(array);
        test_full_mask(array);
        test_double_mask(array);
    }

    /// A mask mixing set and unset bits nulls exactly the set positions.
    fn test_heterogeneous_mask(array: &dyn Array) {
        let bits = [true, false, false, true, true];
        check_positions(array, &apply_mask(array, bits), bits);
    }

    /// An all-false mask keeps every value.
    fn test_empty_mask(array: &dyn Array) {
        let bits = [false; 5];
        check_positions(array, &apply_mask(array, bits), bits);
    }

    /// An all-true mask nulls every position.
    fn test_full_mask(array: &dyn Array) {
        let bits = [true; 5];
        check_positions(array, &apply_mask(array, bits), bits);
    }

    /// Masking an already-masked array nulls the union of both masks.
    fn test_double_mask(array: &dyn Array) {
        let first = apply_mask(array, [true, false, false, true, true]);
        let second = apply_mask(&first, [false, true, false, false, true]);
        check_positions(array, &second, [true, true, false, true, true]);
    }

    /// Builds a [`Mask`] from `bits` and applies it to `array`.
    // Harness failures are reported by panicking, so unwrapping is intentional.
    #[allow(clippy::unwrap_used)]
    fn apply_mask(array: &dyn Array, bits: [bool; 5]) -> ArrayRef {
        let selection = Mask::try_from(&BoolArray::from_iter(bits)).unwrap();
        mask(array, selection).unwrap()
    }

    /// Asserts that `masked` has the length of `original`, is null wherever `nulled` is `true`,
    /// and elsewhere holds the original value as a nullable scalar.
    // Harness failures are reported by panicking, so unwrapping is intentional.
    #[allow(clippy::unwrap_used)]
    fn check_positions(original: &dyn Array, masked: &dyn Array, nulled: [bool; 5]) {
        assert_eq!(masked.len(), original.len());
        for (idx, should_be_null) in nulled.iter().copied().enumerate() {
            if should_be_null {
                assert!(!masked.is_valid(idx).unwrap());
            } else {
                assert_eq!(
                    scalar_at(masked, idx).unwrap(),
                    scalar_at(original, idx).unwrap().into_nullable()
                );
            }
        }
    }
}
210
#[cfg(test)]
mod test {
    use super::test_harness::test_mask;
    use crate::arrays::PrimitiveArray;

    /// A non-nullable primitive array must pass the shared mask harness; the masked results
    /// are expected to carry the nullable dtype.
    #[test]
    fn test_mask_non_nullable_array() {
        test_mask(&PrimitiveArray::from_iter(1..=5));
    }
}