// vortex_array/compute/mask.rs
use std::sync::LazyLock;

use arcref::ArcRef;
use arrow_array::BooleanArray;
use vortex_dtype::DType;
use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
use vortex_mask::Mask;
use vortex_scalar::Scalar;

use crate::arrays::ConstantArray;
use crate::arrow::{FromArrowArray, IntoArrowArray};
use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, cast};
use crate::vtable::VTable;
use crate::{Array, ArrayRef, IntoArray};

19pub fn mask(array: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
49 MASK_FN
50 .invoke(&InvocationArgs {
51 inputs: &[array.into(), mask.into()],
52 options: &(),
53 })?
54 .unwrap_array()
55}
56
57pub struct MaskKernelRef(ArcRef<dyn Kernel>);
58inventory::collect!(MaskKernelRef);
59
60pub trait MaskKernel: VTable {
61 fn mask(&self, array: &Self::Array, mask: &Mask) -> VortexResult<ArrayRef>;
63}
64
65#[derive(Debug)]
66pub struct MaskKernelAdapter<V: VTable>(pub V);
67
68impl<V: VTable + MaskKernel> MaskKernelAdapter<V> {
69 pub const fn lift(&'static self) -> MaskKernelRef {
70 MaskKernelRef(ArcRef::new_ref(self))
71 }
72}
73
74impl<V: VTable + MaskKernel> Kernel for MaskKernelAdapter<V> {
75 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
76 let inputs = MaskArgs::try_from(args)?;
77 let Some(array) = inputs.array.as_opt::<V>() else {
78 return Ok(None);
79 };
80 Ok(Some(V::mask(&self.0, array, inputs.mask)?.into()))
81 }
82}
83
84pub static MASK_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
85 let compute = ComputeFn::new("mask".into(), ArcRef::new_ref(&MaskFn));
86 for kernel in inventory::iter::<MaskKernelRef> {
87 compute.register_kernel(kernel.0.clone());
88 }
89 compute
90});
91
92struct MaskFn;
93
94impl ComputeFnVTable for MaskFn {
95 fn invoke(
96 &self,
97 args: &InvocationArgs,
98 kernels: &[ArcRef<dyn Kernel>],
99 ) -> VortexResult<Output> {
100 let MaskArgs { array, mask } = MaskArgs::try_from(args)?;
101
102 if matches!(mask, Mask::AllFalse(_)) {
103 return Ok(cast(array, &array.dtype().as_nullable())?.into());
105 }
106
107 if matches!(mask, Mask::AllTrue(_)) {
108 return Ok(ConstantArray::new(
110 Scalar::null(array.dtype().clone().as_nullable()),
111 array.len(),
112 )
113 .into_array()
114 .into());
115 }
116
117 for kernel in kernels {
118 if let Some(output) = kernel.invoke(args)? {
119 return Ok(output);
120 }
121 }
122 if let Some(output) = array.invoke(&MASK_FN, args)? {
123 return Ok(output);
124 }
125
126 log::debug!("No mask implementation found for {}", array.encoding_id());
128
129 let array_ref = array.to_array().into_arrow_preferred()?;
130 let mask = BooleanArray::new(mask.to_boolean_buffer(), None);
131
132 let masked = arrow_select::nullif::nullif(array_ref.as_ref(), &mask)?;
133
134 Ok(ArrayRef::from_arrow(masked.as_ref(), true).into())
135 }
136
137 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
138 let MaskArgs { array, .. } = MaskArgs::try_from(args)?;
139 Ok(array.dtype().as_nullable())
140 }
141
142 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
143 let MaskArgs { array, mask } = MaskArgs::try_from(args)?;
144
145 if mask.len() != array.len() {
146 vortex_bail!(
147 "mask.len() is {}, does not equal array.len() of {}",
148 mask.len(),
149 array.len()
150 );
151 }
152
153 Ok(mask.len())
154 }
155
156 fn is_elementwise(&self) -> bool {
157 true
158 }
159}
160
161struct MaskArgs<'a> {
162 array: &'a dyn Array,
163 mask: &'a Mask,
164}
165
166impl<'a> TryFrom<&InvocationArgs<'a>> for MaskArgs<'a> {
167 type Error = VortexError;
168
169 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
170 if value.inputs.len() != 2 {
171 vortex_bail!("Mask function requires 2 arguments");
172 }
173 let array = value.inputs[0]
174 .array()
175 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
176 let mask = value.inputs[1]
177 .mask()
178 .ok_or_else(|| vortex_err!("Expected input 1 to be a mask"))?;
179
180 Ok(MaskArgs { array, mask })
181 }
182}