vortex_array/compute/
mask.rs

1use arrow_array::BooleanArray;
2use vortex_error::{VortexExpect, VortexResult, vortex_bail};
3use vortex_mask::Mask;
4use vortex_scalar::Scalar;
5
6use crate::arrays::ConstantArray;
7use crate::arrow::{FromArrowArray, IntoArrowArray};
8use crate::compute::try_cast;
9use crate::encoding::Encoding;
10use crate::{Array, ArrayRef};
11
12pub trait MaskFn<A> {
13    /// Replace masked values with null in array.
14    fn mask(&self, array: A, mask: Mask) -> VortexResult<ArrayRef>;
15}
16
17impl<E: Encoding> MaskFn<&dyn Array> for E
18where
19    E: for<'a> MaskFn<&'a E::Array>,
20{
21    fn mask(&self, array: &dyn Array, mask: Mask) -> VortexResult<ArrayRef> {
22        let array_ref = array
23            .as_any()
24            .downcast_ref::<E::Array>()
25            .vortex_expect("Failed to downcast array");
26        MaskFn::mask(self, array_ref, mask)
27    }
28}
29
30/// Replace values with null where the mask is true.
31///
32/// The returned array is nullable but otherwise has the same dtype and length as `array`.
33///
34/// # Examples
35///
36/// ```
37/// use vortex_array::IntoArray;
38/// use vortex_array::arrays::{BoolArray, PrimitiveArray};
39/// use vortex_array::compute::{scalar_at, mask};
40/// use vortex_mask::Mask;
41/// use vortex_scalar::Scalar;
42///
43/// let array =
44///     PrimitiveArray::from_option_iter([Some(0i32), None, Some(1i32), None, Some(2i32)]);
45/// let mask_array = Mask::try_from(
46///     &BoolArray::from_iter([true, false, false, false, true]),
47/// )
48/// .unwrap();
49///
50/// let masked = mask(&array, mask_array).unwrap();
51/// assert_eq!(masked.len(), 5);
52/// assert!(!masked.is_valid(0).unwrap());
53/// assert!(!masked.is_valid(1).unwrap());
54/// assert_eq!(scalar_at(&masked, 2).unwrap(), Scalar::from(Some(1)));
55/// assert!(!masked.is_valid(3).unwrap());
56/// assert!(!masked.is_valid(4).unwrap());
57/// ```
58///
59pub fn mask(array: &dyn Array, mask: Mask) -> VortexResult<ArrayRef> {
60    if mask.len() != array.len() {
61        vortex_bail!(
62            "mask.len() is {}, does not equal array.len() of {}",
63            mask.len(),
64            array.len()
65        );
66    }
67
68    let masked = if matches!(mask, Mask::AllFalse(_)) {
69        // Fast-path for empty mask
70        try_cast(array, &array.dtype().as_nullable())?
71    } else if matches!(mask, Mask::AllTrue(_)) {
72        // Fast-path for full mask.
73        ConstantArray::new(
74            Scalar::null(array.dtype().clone().as_nullable()),
75            array.len(),
76        )
77        .into_array()
78    } else {
79        mask_impl(array, mask)?
80    };
81
82    debug_assert_eq!(
83        masked.len(),
84        array.len(),
85        "Mask should not change length {}\n\n{:?}\n\n{:?}",
86        array.encoding(),
87        array,
88        masked
89    );
90    debug_assert_eq!(
91        masked.dtype(),
92        &array.dtype().as_nullable(),
93        "Mask dtype mismatch {} {} {} {}",
94        array.encoding(),
95        masked.dtype(),
96        array.dtype(),
97        array.dtype().as_nullable(),
98    );
99
100    Ok(masked)
101}
102
103fn mask_impl(array: &dyn Array, mask: Mask) -> VortexResult<ArrayRef> {
104    if let Some(mask_fn) = array.vtable().mask_fn() {
105        return mask_fn.mask(array, mask);
106    }
107
108    // Fallback: implement using Arrow kernels.
109    log::debug!("No mask implementation found for {}", array.encoding());
110
111    let array_ref = array.to_array().into_arrow_preferred()?;
112    let mask = BooleanArray::new(mask.to_boolean_buffer(), None);
113
114    let masked = arrow_select::nullif::nullif(array_ref.as_ref(), &mask)?;
115
116    Ok(ArrayRef::from_arrow(masked, true))
117}
118
// Also compiled under `cfg(test)` so that unit tests inside this crate can use the harness
// even when the `test-harness` feature is not enabled.
#[cfg(any(test, feature = "test-harness"))]
pub mod test_harness {
    //! Reusable conformance checks for the `mask` compute function.

    use vortex_mask::Mask;

    use crate::Array;
    use crate::arrays::BoolArray;
    use crate::compute::{mask, scalar_at};

    /// Number of elements every array passed to [`test_mask`] must contain.
    const LEN: usize = 5;

    /// Run all mask conformance checks against `array`.
    ///
    /// The checks compare unmasked positions against `scalar_at(array, i)`, and expect exactly
    /// the masked positions to be invalid afterwards.
    ///
    /// # Panics
    ///
    /// Panics if `array` does not have exactly five elements or if any check fails.
    pub fn test_mask(array: &dyn Array) {
        assert_eq!(array.len(), LEN);
        test_heterogeneous_mask(array);
        test_empty_mask(array);
        test_full_mask(array);
        test_double_mask(array);
    }

    /// Build a [`Mask`] from literal bits (`true` means "null this position out").
    // Unwrapping is intentional: this is test support code and a failure must abort the test.
    #[allow(clippy::unwrap_used)]
    fn mask_from_bits(bits: [bool; LEN]) -> Mask {
        Mask::try_from(&BoolArray::from_iter(bits)).unwrap()
    }

    /// Assert that `masked` is invalid wherever `nulled` is true and otherwise holds the
    /// nullable version of the scalar found in `original`.
    // Unwrapping is intentional: this is test support code and a failure must abort the test.
    #[allow(clippy::unwrap_used)]
    fn assert_nulled_at(original: &dyn Array, masked: &dyn Array, nulled: [bool; LEN]) {
        assert_eq!(masked.len(), original.len());
        for (idx, &is_nulled) in nulled.iter().enumerate() {
            if is_nulled {
                assert!(
                    !masked.is_valid(idx).unwrap(),
                    "index {idx} should be null after masking"
                );
            } else {
                assert_eq!(
                    scalar_at(masked, idx).unwrap(),
                    scalar_at(original, idx).unwrap().into_nullable(),
                    "index {idx} should be unchanged after masking"
                );
            }
        }
    }

    /// A mask with both set and unset bits nulls exactly the set positions.
    // Unwrapping is intentional: this is test support code and a failure must abort the test.
    #[allow(clippy::unwrap_used)]
    fn test_heterogeneous_mask(array: &dyn Array) {
        let bits = [true, false, false, true, true];
        let masked = mask(array, mask_from_bits(bits)).unwrap();
        assert_nulled_at(array, &masked, bits);
    }

    /// A mask with no bits set leaves every value intact (only nullability changes).
    // Unwrapping is intentional: this is test support code and a failure must abort the test.
    #[allow(clippy::unwrap_used)]
    fn test_empty_mask(array: &dyn Array) {
        let bits = [false; LEN];
        let masked = mask(array, mask_from_bits(bits)).unwrap();
        assert_nulled_at(array, &masked, bits);
    }

    /// A mask with every bit set nulls out the whole array.
    // Unwrapping is intentional: this is test support code and a failure must abort the test.
    #[allow(clippy::unwrap_used)]
    fn test_full_mask(array: &dyn Array) {
        let bits = [true; LEN];
        let masked = mask(array, mask_from_bits(bits)).unwrap();
        assert_nulled_at(array, &masked, bits);
    }

    /// Masking an already masked array nulls the union of both masks.
    // Unwrapping is intentional: this is test support code and a failure must abort the test.
    #[allow(clippy::unwrap_used)]
    fn test_double_mask(array: &dyn Array) {
        let first = mask(array, mask_from_bits([true, false, false, true, true])).unwrap();
        let double_masked =
            mask(&first, mask_from_bits([false, true, false, false, true])).unwrap();
        // Position 2 is the only one set in neither mask.
        assert_nulled_at(array, &double_masked, [true, true, false, true, true]);
    }
}
210
#[cfg(test)]
mod test {
    use super::test_harness;
    use crate::arrays::PrimitiveArray;

    /// Runs the shared mask conformance checks on a primitive array built without nulls.
    #[test]
    fn test_mask_non_nullable_array() {
        let values = PrimitiveArray::from_iter([1, 2, 3, 4, 5]);
        test_harness::test_mask(&values);
    }
}
221}