vortex_array/compute/
mask.rs1use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use arrow_array::BooleanArray;
8use vortex_dtype::DType;
9use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
10use vortex_mask::Mask;
11use vortex_scalar::Scalar;
12
13use crate::arrays::ConstantArray;
14use crate::arrow::{FromArrowArray, IntoArrowArray};
15use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, cast};
16use crate::vtable::VTable;
17use crate::{Array, ArrayRef, IntoArray};
18
19static MASK_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
20 let compute = ComputeFn::new("mask".into(), ArcRef::new_ref(&MaskFn));
21 for kernel in inventory::iter::<MaskKernelRef> {
22 compute.register_kernel(kernel.0.clone());
23 }
24 compute
25});
26
27pub(crate) fn warm_up_vtable() -> usize {
28 MASK_FN.kernels().len()
29}
30
31pub fn mask(array: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
58 MASK_FN
59 .invoke(&InvocationArgs {
60 inputs: &[array.into(), mask.into()],
61 options: &(),
62 })?
63 .unwrap_array()
64}
65
/// A type-erased handle to a `mask` [`Kernel`], collected through `inventory`.
///
/// Every handle collected here is registered on [`MASK_FN`] when that static is first
/// initialised. Handles are normally produced by [`MaskKernelAdapter::lift`].
pub struct MaskKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(MaskKernelRef);
68
/// An encoding-specific implementation of the `mask` compute function.
///
/// Implement this on an encoding's [`VTable`] and wrap it in a [`MaskKernelAdapter`]. Then
/// submit the lifted [`MaskKernelRef`] to `inventory` (presumably via `inventory::submit!`).
/// [`mask`] can then use it instead of the Arrow fallback.
pub trait MaskKernel: VTable {
    /// Nulls out the positions of `array` where `mask` is `true` and returns a nullable array.
    ///
    /// NOTE(review): the all-true and all-false masks are handled by `MaskFn` before any
    /// kernel is consulted. The mask length is assumed to equal `array.len()`; confirm the
    /// compute framework validates this before dispatch.
    fn mask(&self, array: &Self::Array, mask: &Mask) -> VortexResult<ArrayRef>;
}
73
74#[derive(Debug)]
75pub struct MaskKernelAdapter<V: VTable>(pub V);
76
77impl<V: VTable + MaskKernel> MaskKernelAdapter<V> {
78 pub const fn lift(&'static self) -> MaskKernelRef {
79 MaskKernelRef(ArcRef::new_ref(self))
80 }
81}
82
83impl<V: VTable + MaskKernel> Kernel for MaskKernelAdapter<V> {
84 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
85 let inputs = MaskArgs::try_from(args)?;
86 let Some(array) = inputs.array.as_opt::<V>() else {
87 return Ok(None);
88 };
89 Ok(Some(V::mask(&self.0, array, inputs.mask)?.into()))
90 }
91}
92
93struct MaskFn;
94
95impl ComputeFnVTable for MaskFn {
96 fn invoke(
97 &self,
98 args: &InvocationArgs,
99 kernels: &[ArcRef<dyn Kernel>],
100 ) -> VortexResult<Output> {
101 let MaskArgs { array, mask } = MaskArgs::try_from(args)?;
102
103 if matches!(mask, Mask::AllFalse(_)) {
104 return Ok(cast(array, &array.dtype().as_nullable())?.into());
106 }
107
108 if matches!(mask, Mask::AllTrue(_)) {
109 return Ok(ConstantArray::new(
111 Scalar::null(array.dtype().clone().as_nullable()),
112 array.len(),
113 )
114 .into_array()
115 .into());
116 }
117
118 for kernel in kernels {
119 if let Some(output) = kernel.invoke(args)? {
120 return Ok(output);
121 }
122 }
123 if let Some(output) = array.invoke(&MASK_FN, args)? {
124 return Ok(output);
125 }
126
127 log::debug!("No mask implementation found for {}", array.encoding_id());
129
130 let array_ref = array.to_array().into_arrow_preferred()?;
131 let mask = BooleanArray::new(mask.to_boolean_buffer(), None);
132
133 let masked = arrow_select::nullif::nullif(array_ref.as_ref(), &mask)?;
134
135 Ok(ArrayRef::from_arrow(masked.as_ref(), true).into())
136 }
137
138 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
139 let MaskArgs { array, .. } = MaskArgs::try_from(args)?;
140 Ok(array.dtype().as_nullable())
141 }
142
143 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
144 let MaskArgs { array, mask } = MaskArgs::try_from(args)?;
145
146 if mask.len() != array.len() {
147 vortex_bail!(
148 "mask.len() is {}, does not equal array.len() of {}",
149 mask.len(),
150 array.len()
151 );
152 }
153
154 Ok(mask.len())
155 }
156
157 fn is_elementwise(&self) -> bool {
158 true
159 }
160}
161
162struct MaskArgs<'a> {
163 array: &'a dyn Array,
164 mask: &'a Mask,
165}
166
167impl<'a> TryFrom<&InvocationArgs<'a>> for MaskArgs<'a> {
168 type Error = VortexError;
169
170 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
171 if value.inputs.len() != 2 {
172 vortex_bail!("Mask function requires 2 arguments");
173 }
174 let array = value.inputs[0]
175 .array()
176 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
177 let mask = value.inputs[1]
178 .mask()
179 .ok_or_else(|| vortex_err!("Expected input 1 to be a mask"))?;
180
181 Ok(MaskArgs { array, mask })
182 }
183}