vortex_array/compute/mask.rs

1use arrow_array::BooleanArray;
2use vortex_error::{VortexExpect, VortexResult, vortex_bail};
3use vortex_mask::Mask;
4use vortex_scalar::Scalar;
5
6use crate::arrays::ConstantArray;
7use crate::arrow::{FromArrowArray, IntoArrowArray};
8use crate::compute::try_cast;
9use crate::encoding::Encoding;
10use crate::{Array, ArrayRef};
11
/// Encoding-specific kernel that replaces masked values with null.
///
/// `A` is the array argument type. Encodings implement this for a reference to their own
/// concrete array type (`&E::Array`), and a blanket impl lifts such an implementation so it
/// can also be called with a type-erased `&dyn Array`.
pub trait MaskFn<A> {
    /// Replace masked values with null in array.
    ///
    /// Positions where `mask` is true become null. When this kernel is reached through
    /// [`mask()`], the caller has already verified that `mask.len() == array.len()` and has
    /// already handled all-true and all-false masks itself. The caller also `debug_assert`s
    /// that the returned array keeps the input length and has the input dtype made nullable.
    fn mask(&self, array: A, mask: Mask) -> VortexResult<ArrayRef>;
}
16
17impl<E: Encoding> MaskFn<&dyn Array> for E
18where
19    E: for<'a> MaskFn<&'a E::Array>,
20{
21    fn mask(&self, array: &dyn Array, mask: Mask) -> VortexResult<ArrayRef> {
22        let array_ref = array
23            .as_any()
24            .downcast_ref::<E::Array>()
25            .vortex_expect("Failed to downcast array");
26        MaskFn::mask(self, array_ref, mask)
27    }
28}
29
30/// Replace values with null where the mask is true.
31///
32/// The returned array is nullable but otherwise has the same dtype and length as `array`.
33///
34/// # Examples
35///
36/// ```
37/// use vortex_array::IntoArray;
38/// use vortex_array::arrays::{BoolArray, PrimitiveArray};
39/// use vortex_array::compute::{scalar_at, mask};
40/// use vortex_mask::Mask;
41/// use vortex_scalar::Scalar;
42///
43/// let array =
44///     PrimitiveArray::from_option_iter([Some(0i32), None, Some(1i32), None, Some(2i32)]);
45/// let mask_array = Mask::try_from(
46///     &BoolArray::from_iter([true, false, false, false, true]),
47/// )
48/// .unwrap();
49///
50/// let masked = mask(&array, mask_array).unwrap();
51/// assert_eq!(masked.len(), 5);
52/// assert!(!masked.is_valid(0).unwrap());
53/// assert!(!masked.is_valid(1).unwrap());
54/// assert_eq!(scalar_at(&masked, 2).unwrap(), Scalar::from(Some(1)));
55/// assert!(!masked.is_valid(3).unwrap());
56/// assert!(!masked.is_valid(4).unwrap());
57/// ```
58///
59pub fn mask(array: &dyn Array, mask: Mask) -> VortexResult<ArrayRef> {
60    if mask.len() != array.len() {
61        vortex_bail!(
62            "mask.len() is {}, does not equal array.len() of {}",
63            mask.len(),
64            array.len()
65        );
66    }
67
68    let masked = if matches!(mask, Mask::AllFalse(_)) {
69        // Fast-path for empty mask
70        try_cast(array, &array.dtype().as_nullable())?
71    } else if matches!(mask, Mask::AllTrue(_)) {
72        // Fast-path for full mask.
73        ConstantArray::new(
74            Scalar::null(array.dtype().clone().as_nullable()),
75            array.len(),
76        )
77        .into_array()
78    } else {
79        mask_impl(array, mask)?
80    };
81
82    debug_assert_eq!(
83        masked.len(),
84        array.len(),
85        "Mask should not change length {}\n\n{:?}\n\n{:?}",
86        array.encoding(),
87        array,
88        masked
89    );
90    debug_assert_eq!(
91        masked.dtype(),
92        &array.dtype().as_nullable(),
93        "Mask dtype mismatch {} {} {} {}",
94        array.encoding(),
95        masked.dtype(),
96        array.dtype(),
97        array.dtype().as_nullable(),
98    );
99
100    Ok(masked)
101}
102
103fn mask_impl(array: &dyn Array, mask: Mask) -> VortexResult<ArrayRef> {
104    if let Some(mask_fn) = array.vtable().mask_fn() {
105        return mask_fn.mask(array, mask);
106    }
107
108    // Fallback: implement using Arrow kernels.
109    log::debug!("No mask implementation found for {}", array.encoding());
110
111    let array_ref = array.to_array().into_arrow_preferred()?;
112    let mask = BooleanArray::new(mask.to_boolean_buffer(), None);
113
114    let masked = arrow_select::nullif::nullif(array_ref.as_ref(), &mask)?;
115
116    Ok(ArrayRef::from_arrow(masked, true))
117}