vortex_array/compute/
mask.rs

1use std::sync::LazyLock;
2
3use arrow_array::BooleanArray;
4use vortex_dtype::DType;
5use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
6use vortex_mask::Mask;
7use vortex_scalar::Scalar;
8
9use crate::arcref::ArcRef;
10use crate::arrays::ConstantArray;
11use crate::arrow::{FromArrowArray, IntoArrowArray};
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, cast};
13use crate::encoding::Encoding;
14use crate::{Array, ArrayRef};
15
16/// Replace values with null where the mask is true.
17///
18/// The returned array is nullable but otherwise has the same dtype and length as `array`.
19///
20/// # Examples
21///
22/// ```
23/// use vortex_array::IntoArray;
24/// use vortex_array::arrays::{BoolArray, PrimitiveArray};
25/// use vortex_array::compute::{scalar_at, mask};
26/// use vortex_mask::Mask;
27/// use vortex_scalar::Scalar;
28///
29/// let array =
30///     PrimitiveArray::from_option_iter([Some(0i32), None, Some(1i32), None, Some(2i32)]);
31/// let mask_array = Mask::try_from(
32///     &BoolArray::from_iter([true, false, false, false, true]),
33/// )
34/// .unwrap();
35///
36/// let masked = mask(&array, &mask_array).unwrap();
37/// assert_eq!(masked.len(), 5);
38/// assert!(!masked.is_valid(0).unwrap());
39/// assert!(!masked.is_valid(1).unwrap());
40/// assert_eq!(scalar_at(&masked, 2).unwrap(), Scalar::from(Some(1)));
41/// assert!(!masked.is_valid(3).unwrap());
42/// assert!(!masked.is_valid(4).unwrap());
43/// ```
44///
45pub fn mask(array: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
46    MASK_FN
47        .invoke(&InvocationArgs {
48            inputs: &[array.into(), mask.into()],
49            options: &(),
50        })?
51        .unwrap_array()
52}
53
/// Handle to a type-erased [`Kernel`] that implements the `mask` compute function.
///
/// Values of this type are gathered through `inventory` and registered with
/// [`MASK_FN`] the first time that static is initialized.
pub struct MaskKernelRef(ArcRef<dyn Kernel>);
// Declares `MaskKernelRef` as a type that `inventory` collects.
inventory::collect!(MaskKernelRef);
56
/// Encoding-specific implementation of the `mask` compute function.
///
/// Wrap an implementor in [`MaskKernelAdapter`] to expose it as a generic [`Kernel`].
pub trait MaskKernel: Encoding {
    /// Replace masked values with null in array.
    ///
    /// `array` is this encoding's concrete array type and `mask` selects the positions
    /// to null out (a true entry means the value becomes null).
    fn mask(&self, array: &Self::Array, mask: &Mask) -> VortexResult<ArrayRef>;
}
61
/// Adapter that exposes an encoding's [`MaskKernel`] implementation as a generic
/// [`Kernel`].
#[derive(Debug)]
pub struct MaskKernelAdapter<E: Encoding>(pub E);

impl<E: Encoding + MaskKernel> MaskKernelAdapter<E> {
    /// Wrap a `'static` reference to this adapter in a [`MaskKernelRef`].
    ///
    /// NOTE(review): presumably the result is what gets submitted to `inventory` so
    /// that [`MASK_FN`] picks the kernel up -- confirm against the encodings' call sites.
    pub const fn lift(&'static self) -> MaskKernelRef {
        MaskKernelRef(ArcRef::new_ref(self))
    }
}
70
71impl<E: Encoding + MaskKernel> Kernel for MaskKernelAdapter<E> {
72    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
73        let inputs = MaskArgs::try_from(args)?;
74        let Some(array) = inputs.array.as_any().downcast_ref::<E::Array>() else {
75            return Ok(None);
76        };
77        Ok(Some(E::mask(&self.0, array, inputs.mask)?.into()))
78    }
79}
80
81pub static MASK_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
82    let compute = ComputeFn::new("mask".into(), ArcRef::new_ref(&MaskFn));
83    for kernel in inventory::iter::<MaskKernelRef> {
84        compute.register_kernel(kernel.0.clone());
85    }
86    compute
87});
88
89struct MaskFn;
90
91impl ComputeFnVTable for MaskFn {
92    fn invoke(
93        &self,
94        args: &InvocationArgs,
95        kernels: &[ArcRef<dyn Kernel>],
96    ) -> VortexResult<Output> {
97        let MaskArgs { array, mask } = MaskArgs::try_from(args)?;
98
99        if matches!(mask, Mask::AllFalse(_)) {
100            // Fast-path for empty mask
101            return Ok(cast(array, &array.dtype().as_nullable())?.into());
102        }
103
104        if matches!(mask, Mask::AllTrue(_)) {
105            // Fast-path for full mask.
106            return Ok(ConstantArray::new(
107                Scalar::null(array.dtype().clone().as_nullable()),
108                array.len(),
109            )
110            .into_array()
111            .into());
112        }
113
114        for kernel in kernels {
115            if let Some(output) = kernel.invoke(args)? {
116                return Ok(output);
117            }
118        }
119        if let Some(output) = array.invoke(&MASK_FN, args)? {
120            return Ok(output);
121        }
122
123        // Fallback: implement using Arrow kernels.
124        log::debug!("No mask implementation found for {}", array.encoding());
125
126        let array_ref = array.to_array().into_arrow_preferred()?;
127        let mask = BooleanArray::new(mask.to_boolean_buffer(), None);
128
129        let masked = arrow_select::nullif::nullif(array_ref.as_ref(), &mask)?;
130
131        Ok(ArrayRef::from_arrow(masked, true).into())
132    }
133
134    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
135        let MaskArgs { array, .. } = MaskArgs::try_from(args)?;
136        Ok(array.dtype().as_nullable())
137    }
138
139    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
140        let MaskArgs { array, mask } = MaskArgs::try_from(args)?;
141
142        if mask.len() != array.len() {
143            vortex_bail!(
144                "mask.len() is {}, does not equal array.len() of {}",
145                mask.len(),
146                array.len()
147            );
148        }
149
150        Ok(mask.len())
151    }
152
153    fn is_elementwise(&self) -> bool {
154        true
155    }
156}
157
158struct MaskArgs<'a> {
159    array: &'a dyn Array,
160    mask: &'a Mask,
161}
162
163impl<'a> TryFrom<&InvocationArgs<'a>> for MaskArgs<'a> {
164    type Error = VortexError;
165
166    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
167        if value.inputs.len() != 2 {
168            vortex_bail!("Mask function requires 2 arguments");
169        }
170        let array = value.inputs[0]
171            .array()
172            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
173        let mask = value.inputs[1]
174            .mask()
175            .ok_or_else(|| vortex_err!("Expected input 1 to be a mask"))?;
176
177        Ok(MaskArgs { array, mask })
178    }
179}