vortex_array/compute/mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use arrow_array::BooleanArray;
8use vortex_dtype::DType;
9use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
10use vortex_mask::Mask;
11use vortex_scalar::Scalar;
12
13use crate::arrays::ConstantArray;
14use crate::arrow::{FromArrowArray, IntoArrowArray};
15use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, cast};
16use crate::vtable::VTable;
17use crate::{Array, ArrayRef, IntoArray};
18
19static MASK_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
20    let compute = ComputeFn::new("mask".into(), ArcRef::new_ref(&MaskFn));
21    for kernel in inventory::iter::<MaskKernelRef> {
22        compute.register_kernel(kernel.0.clone());
23    }
24    compute
25});
26
27pub(crate) fn warm_up_vtable() -> usize {
28    MASK_FN.kernels().len()
29}
30
31/// Replace values with null where the mask is true.
32///
33/// The returned array is nullable but otherwise has the same dtype and length as `array`.
34///
35/// # Examples
36///
37/// ```
38/// use vortex_array::IntoArray;
39/// use vortex_array::arrays::{BoolArray, PrimitiveArray};
40/// use vortex_array::compute::{ mask};
41/// use vortex_mask::Mask;
42/// use vortex_scalar::Scalar;
43///
44/// let array =
45///     PrimitiveArray::from_option_iter([Some(0i32), None, Some(1i32), None, Some(2i32)]);
46/// let mask_array = Mask::from_iter([true, false, false, false, true]);
47///
48/// let masked = mask(array.as_ref(), &mask_array).unwrap();
49/// assert_eq!(masked.len(), 5);
50/// assert!(!masked.is_valid(0));
51/// assert!(!masked.is_valid(1));
52/// assert_eq!(masked.scalar_at(2), Scalar::from(Some(1)));
53/// assert!(!masked.is_valid(3));
54/// assert!(!masked.is_valid(4));
55/// ```
56///
57pub fn mask(array: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
58    MASK_FN
59        .invoke(&InvocationArgs {
60            inputs: &[array.into(), mask.into()],
61            options: &(),
62        })?
63        .unwrap_array()
64}
65
66pub struct MaskKernelRef(ArcRef<dyn Kernel>);
67inventory::collect!(MaskKernelRef);
68
69pub trait MaskKernel: VTable {
70    /// Replace masked values with null in array.
71    fn mask(&self, array: &Self::Array, mask: &Mask) -> VortexResult<ArrayRef>;
72}
73
74#[derive(Debug)]
75pub struct MaskKernelAdapter<V: VTable>(pub V);
76
77impl<V: VTable + MaskKernel> MaskKernelAdapter<V> {
78    pub const fn lift(&'static self) -> MaskKernelRef {
79        MaskKernelRef(ArcRef::new_ref(self))
80    }
81}
82
83impl<V: VTable + MaskKernel> Kernel for MaskKernelAdapter<V> {
84    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
85        let inputs = MaskArgs::try_from(args)?;
86        let Some(array) = inputs.array.as_opt::<V>() else {
87            return Ok(None);
88        };
89        Ok(Some(V::mask(&self.0, array, inputs.mask)?.into()))
90    }
91}
92
93struct MaskFn;
94
95impl ComputeFnVTable for MaskFn {
96    fn invoke(
97        &self,
98        args: &InvocationArgs,
99        kernels: &[ArcRef<dyn Kernel>],
100    ) -> VortexResult<Output> {
101        let MaskArgs { array, mask } = MaskArgs::try_from(args)?;
102
103        if matches!(mask, Mask::AllFalse(_)) {
104            // Fast-path for empty mask
105            return Ok(cast(array, &array.dtype().as_nullable())?.into());
106        }
107
108        if matches!(mask, Mask::AllTrue(_)) {
109            // Fast-path for full mask.
110            return Ok(ConstantArray::new(
111                Scalar::null(array.dtype().clone().as_nullable()),
112                array.len(),
113            )
114            .into_array()
115            .into());
116        }
117
118        // Do nothing if the array is already all nulls.
119        if array.all_invalid() {
120            return Ok(array.to_array().into());
121        }
122
123        for kernel in kernels {
124            if let Some(output) = kernel.invoke(args)? {
125                return Ok(output);
126            }
127        }
128        if let Some(output) = array.invoke(&MASK_FN, args)? {
129            return Ok(output);
130        }
131
132        // Fallback: implement using Arrow kernels.
133        log::debug!("No mask implementation found for {}", array.encoding_id());
134
135        let array_ref = array.to_array().into_arrow_preferred()?;
136        let mask = BooleanArray::new(mask.to_boolean_buffer(), None);
137
138        let masked = arrow_select::nullif::nullif(array_ref.as_ref(), &mask)?;
139
140        Ok(ArrayRef::from_arrow(masked.as_ref(), true).into())
141    }
142
143    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
144        let MaskArgs { array, .. } = MaskArgs::try_from(args)?;
145        Ok(array.dtype().as_nullable())
146    }
147
148    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
149        let MaskArgs { array, mask } = MaskArgs::try_from(args)?;
150
151        if mask.len() != array.len() {
152            vortex_bail!(
153                "mask.len() is {}, does not equal array.len() of {}",
154                mask.len(),
155                array.len()
156            );
157        }
158
159        Ok(mask.len())
160    }
161
162    fn is_elementwise(&self) -> bool {
163        true
164    }
165}
166
167struct MaskArgs<'a> {
168    array: &'a dyn Array,
169    mask: &'a Mask,
170}
171
172impl<'a> TryFrom<&InvocationArgs<'a>> for MaskArgs<'a> {
173    type Error = VortexError;
174
175    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
176        if value.inputs.len() != 2 {
177            vortex_bail!("Mask function requires 2 arguments");
178        }
179        let array = value.inputs[0]
180            .array()
181            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
182        let mask = value.inputs[1]
183            .mask()
184            .ok_or_else(|| vortex_err!("Expected input 1 to be a mask"))?;
185
186        Ok(MaskArgs { array, mask })
187    }
188}