vortex_array/compute/mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use arrow_array::BooleanArray;
8use vortex_dtype::DType;
9use vortex_error::VortexError;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_mask::Mask;
14use vortex_scalar::Scalar;
15
16use crate::Array;
17use crate::ArrayRef;
18use crate::IntoArray;
19use crate::arrays::ConstantArray;
20use crate::arrow::FromArrowArray;
21use crate::arrow::IntoArrowArray;
22use crate::compute::ComputeFn;
23use crate::compute::ComputeFnVTable;
24use crate::compute::InvocationArgs;
25use crate::compute::Kernel;
26use crate::compute::Output;
27use crate::compute::cast;
28use crate::vtable::VTable;
29
30static MASK_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
31    let compute = ComputeFn::new("mask".into(), ArcRef::new_ref(&MaskFn));
32    for kernel in inventory::iter::<MaskKernelRef> {
33        compute.register_kernel(kernel.0.clone());
34    }
35    compute
36});
37
38pub(crate) fn warm_up_vtable() -> usize {
39    MASK_FN.kernels().len()
40}
41
42/// Replace values with null where the mask is true.
43///
44/// The returned array is nullable but otherwise has the same dtype and length as `array`.
45///
46/// # Examples
47///
48/// ```
49/// use vortex_array::IntoArray;
50/// use vortex_array::arrays::{BoolArray, PrimitiveArray};
51/// use vortex_array::compute::{ mask};
52/// use vortex_mask::Mask;
53/// use vortex_scalar::Scalar;
54///
55/// let array =
56///     PrimitiveArray::from_option_iter([Some(0i32), None, Some(1i32), None, Some(2i32)]);
57/// let mask_array = Mask::from_iter([true, false, false, false, true]);
58///
59/// let masked = mask(array.as_ref(), &mask_array).unwrap();
60/// assert_eq!(masked.len(), 5);
61/// assert!(!masked.is_valid(0));
62/// assert!(!masked.is_valid(1));
63/// assert_eq!(masked.scalar_at(2), Scalar::from(Some(1)));
64/// assert!(!masked.is_valid(3));
65/// assert!(!masked.is_valid(4));
66/// ```
67///
68pub fn mask(array: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
69    MASK_FN
70        .invoke(&InvocationArgs {
71            inputs: &[array.into(), mask.into()],
72            options: &(),
73        })?
74        .unwrap_array()
75}
76
/// A type-erased [`Kernel`] for the [`mask()`] compute function.
///
/// Instances are gathered through `inventory`. Every collected `MaskKernelRef` is registered
/// on the mask compute function the first time that function is accessed. The wrapped field
/// is private, so outside this module an instance can only be built with
/// [`MaskKernelAdapter::lift`].
pub struct MaskKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(MaskKernelRef);
79
/// Encoding-specific implementation of the [`mask()`] compute function.
///
/// To take part in dispatch, an implementation is wrapped in a [`MaskKernelAdapter`] and
/// lifted into a [`MaskKernelRef`]. That value is then collected through `inventory`. The
/// adapter only calls [`MaskKernel::mask`] when the input array downcasts to `Self::Array`.
///
/// The dispatcher in this module handles three cases itself and never passes them to
/// kernels: all-false masks, all-true masks, and arrays that are already entirely null.
pub trait MaskKernel: VTable {
    /// Replace masked values with null in array.
    ///
    /// Positions where `mask` is `true` should become null, and the other positions keep
    /// their value. Callers of `mask()` are promised a nullable result with `array`'s dtype
    /// and length. NOTE(review): nothing in this module verifies that a kernel's output
    /// honors that contract.
    fn mask(&self, array: &Self::Array, mask: &Mask) -> VortexResult<ArrayRef>;
}
84
/// Adapts a [`MaskKernel`] implementation, identified by its [`VTable`] `V`, into a
/// type-erased [`Kernel`].
///
/// Its [`Kernel::invoke`] returns `Ok(None)` when the input array is not encoded as `V`.
/// That lets the dispatcher move on to the next registered kernel.
#[derive(Debug)]
pub struct MaskKernelAdapter<V: VTable>(pub V);
87
88impl<V: VTable + MaskKernel> MaskKernelAdapter<V> {
89    pub const fn lift(&'static self) -> MaskKernelRef {
90        MaskKernelRef(ArcRef::new_ref(self))
91    }
92}
93
94impl<V: VTable + MaskKernel> Kernel for MaskKernelAdapter<V> {
95    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
96        let inputs = MaskArgs::try_from(args)?;
97        let Some(array) = inputs.array.as_opt::<V>() else {
98            return Ok(None);
99        };
100        Ok(Some(V::mask(&self.0, array, inputs.mask)?.into()))
101    }
102}
103
/// The [`ComputeFnVTable`] behind the `mask` compute function.
///
/// `invoke` tries strategies in this order:
/// 1. all-false / all-true mask fast paths;
/// 2. the already-all-null array short-circuit;
/// 3. registered kernels;
/// 4. the array's own `invoke`;
/// 5. an Arrow `nullif` fallback.
struct MaskFn;
105
106impl ComputeFnVTable for MaskFn {
107    fn invoke(
108        &self,
109        args: &InvocationArgs,
110        kernels: &[ArcRef<dyn Kernel>],
111    ) -> VortexResult<Output> {
112        let MaskArgs { array, mask } = MaskArgs::try_from(args)?;
113
114        if matches!(mask, Mask::AllFalse(_)) {
115            // Fast-path for empty mask
116            return Ok(cast(array, &array.dtype().as_nullable())?.into());
117        }
118
119        if matches!(mask, Mask::AllTrue(_)) {
120            // Fast-path for full mask.
121            return Ok(ConstantArray::new(
122                Scalar::null(array.dtype().clone().as_nullable()),
123                array.len(),
124            )
125            .into_array()
126            .into());
127        }
128
129        // Do nothing if the array is already all nulls.
130        if array.all_invalid() {
131            return Ok(array.to_array().into());
132        }
133
134        for kernel in kernels {
135            if let Some(output) = kernel.invoke(args)? {
136                return Ok(output);
137            }
138        }
139        if let Some(output) = array.invoke(&MASK_FN, args)? {
140            return Ok(output);
141        }
142
143        // Fallback: implement using Arrow kernels.
144        log::debug!("No mask implementation found for {}", array.encoding_id());
145
146        let array_ref = array.to_array().into_arrow_preferred()?;
147        let mask = BooleanArray::new(mask.to_bit_buffer().into(), None);
148
149        let masked = arrow_select::nullif::nullif(array_ref.as_ref(), &mask)?;
150
151        Ok(ArrayRef::from_arrow(masked.as_ref(), true).into())
152    }
153
154    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
155        let MaskArgs { array, .. } = MaskArgs::try_from(args)?;
156        Ok(array.dtype().as_nullable())
157    }
158
159    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
160        let MaskArgs { array, mask } = MaskArgs::try_from(args)?;
161
162        if mask.len() != array.len() {
163            vortex_bail!(
164                "mask.len() is {}, does not equal array.len() of {}",
165                mask.len(),
166                array.len()
167            );
168        }
169
170        Ok(mask.len())
171    }
172
173    fn is_elementwise(&self) -> bool {
174        true
175    }
176}
177
/// Typed view of the two inputs to the mask compute function. Input 0 is the array and
/// input 1 is the mask.
struct MaskArgs<'a> {
    /// The array whose values are being masked.
    array: &'a dyn Array,
    /// Positions set to `true` are replaced with null.
    mask: &'a Mask,
}
182
183impl<'a> TryFrom<&InvocationArgs<'a>> for MaskArgs<'a> {
184    type Error = VortexError;
185
186    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
187        if value.inputs.len() != 2 {
188            vortex_bail!("Mask function requires 2 arguments");
189        }
190        let array = value.inputs[0]
191            .array()
192            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
193        let mask = value.inputs[1]
194            .mask()
195            .ok_or_else(|| vortex_err!("Expected input 1 to be a mask"))?;
196
197        Ok(MaskArgs { array, mask })
198    }
199}