vortex_array/arrays/masked/
execute.rs1use std::ops::BitAnd;
7use std::sync::Arc;
8
9use vortex_error::VortexResult;
10use vortex_mask::Mask;
11
12use crate::Canonical;
13use crate::IntoArray;
14use crate::arrays::BoolArray;
15use crate::arrays::DecimalArray;
16use crate::arrays::ExtensionArray;
17use crate::arrays::FixedSizeListArray;
18use crate::arrays::ListViewArray;
19use crate::arrays::MaskedArray;
20use crate::arrays::NullArray;
21use crate::arrays::PrimitiveArray;
22use crate::arrays::StructArray;
23use crate::arrays::VarBinViewArray;
24use crate::arrays::VariantArray;
25use crate::arrays::bool::BoolArrayExt;
26use crate::arrays::extension::ExtensionArrayExt;
27use crate::arrays::fixed_size_list::FixedSizeListArrayExt;
28use crate::arrays::listview::ListViewArrayExt;
29use crate::arrays::struct_::StructArrayExt;
30use crate::arrays::variant::VariantArrayExt;
31use crate::dtype::Nullability;
32use crate::executor::ExecutionCtx;
33use crate::match_each_decimal_value_type;
34use crate::validity::Validity;
35
36pub fn mask_validity_canonical(
42 canonical: Canonical,
43 validity_mask: &Mask,
44 ctx: &mut ExecutionCtx,
45) -> VortexResult<Canonical> {
46 Ok(match canonical {
47 Canonical::Null(a) => Canonical::Null(mask_validity_null(a, validity_mask)),
48 Canonical::Bool(a) => Canonical::Bool(mask_validity_bool(a, validity_mask, ctx)?),
49 Canonical::Primitive(a) => {
50 Canonical::Primitive(mask_validity_primitive(a, validity_mask, ctx)?)
51 }
52 Canonical::Decimal(a) => Canonical::Decimal(mask_validity_decimal(a, validity_mask, ctx)?),
53 Canonical::VarBinView(a) => {
54 Canonical::VarBinView(mask_validity_varbinview(a, validity_mask, ctx)?)
55 }
56 Canonical::List(a) => Canonical::List(mask_validity_listview(a, validity_mask, ctx)?),
57 Canonical::FixedSizeList(a) => {
58 Canonical::FixedSizeList(mask_validity_fixed_size_list(a, validity_mask, ctx)?)
59 }
60 Canonical::Struct(a) => Canonical::Struct(mask_validity_struct(a, validity_mask, ctx)?),
61 Canonical::Extension(a) => {
62 Canonical::Extension(mask_validity_extension(a, validity_mask, ctx)?)
63 }
64 Canonical::Variant(a) => Canonical::Variant(mask_validity_variant(a, validity_mask, ctx)?),
65 })
66}
67
68fn combine_validity(
69 validity: &Validity,
70 mask: &Mask,
71 len: usize,
72 ctx: &mut ExecutionCtx,
73) -> VortexResult<Validity> {
74 let current_mask = validity.execute_mask(len, ctx)?;
75 let combined = current_mask.bitand(mask);
76 Ok(Validity::from_mask(combined, Nullability::Nullable))
77}
78
79fn mask_validity_null(array: NullArray, _mask: &Mask) -> NullArray {
80 array
81}
82
83fn mask_validity_bool(
84 array: BoolArray,
85 mask: &Mask,
86 ctx: &mut ExecutionCtx,
87) -> VortexResult<BoolArray> {
88 let len = array.len();
89 let new_validity = combine_validity(&array.validity()?, mask, len, ctx)?;
90 Ok(BoolArray::new(array.to_bit_buffer(), new_validity))
91}
92
93fn mask_validity_primitive(
94 array: PrimitiveArray,
95 mask: &Mask,
96 ctx: &mut ExecutionCtx,
97) -> VortexResult<PrimitiveArray> {
98 let len = array.len();
99 let ptype = array.ptype();
100 let new_validity = combine_validity(&array.validity()?, mask, len, ctx)?;
101 Ok(unsafe {
103 PrimitiveArray::new_unchecked_from_handle(
104 array.buffer_handle().clone(),
105 ptype,
106 new_validity,
107 )
108 })
109}
110
111fn mask_validity_decimal(
112 array: DecimalArray,
113 mask: &Mask,
114 ctx: &mut ExecutionCtx,
115) -> VortexResult<DecimalArray> {
116 let len = array.len();
117 let dec_dtype = array.decimal_dtype();
118 let values_type = array.values_type();
119 let new_validity = combine_validity(&array.validity()?, mask, len, ctx)?;
120 Ok(match_each_decimal_value_type!(values_type, |T| {
122 let buffer = array.buffer::<T>();
123 unsafe { DecimalArray::new_unchecked(buffer, dec_dtype, new_validity) }
124 }))
125}
126
127fn mask_validity_varbinview(
129 array: VarBinViewArray,
130 mask: &Mask,
131 ctx: &mut ExecutionCtx,
132) -> VortexResult<VarBinViewArray> {
133 let len = array.len();
134 let dtype = array.dtype().as_nullable();
135 let new_validity = combine_validity(&array.validity()?, mask, len, ctx)?;
136 Ok(unsafe {
138 VarBinViewArray::new_handle_unchecked(
139 array.views_handle().clone(),
140 Arc::clone(array.data_buffers()),
141 dtype,
142 new_validity,
143 )
144 })
145}
146
147fn mask_validity_listview(
148 array: ListViewArray,
149 mask: &Mask,
150 ctx: &mut ExecutionCtx,
151) -> VortexResult<ListViewArray> {
152 let len = array.len();
153 let new_validity = combine_validity(&array.validity()?, mask, len, ctx)?;
154 Ok(unsafe {
156 ListViewArray::new_unchecked(
157 array.elements().clone(),
158 array.offsets().clone(),
159 array.sizes().clone(),
160 new_validity,
161 )
162 })
163}
164
165fn mask_validity_fixed_size_list(
166 array: FixedSizeListArray,
167 mask: &Mask,
168 ctx: &mut ExecutionCtx,
169) -> VortexResult<FixedSizeListArray> {
170 let len = array.len();
171 let list_size = array.list_size();
172 let new_validity = combine_validity(&array.validity()?, mask, len, ctx)?;
173 Ok(unsafe {
175 FixedSizeListArray::new_unchecked(array.elements().clone(), list_size, new_validity, len)
176 })
177}
178
179fn mask_validity_struct(
180 array: StructArray,
181 mask: &Mask,
182 ctx: &mut ExecutionCtx,
183) -> VortexResult<StructArray> {
184 let len = array.len();
185 let new_validity = combine_validity(&array.validity()?, mask, len, ctx)?;
186 let fields = array.unmasked_fields();
187 let struct_fields = array.struct_fields();
188 Ok(unsafe { StructArray::new_unchecked(fields, struct_fields.clone(), len, new_validity) })
190}
191
192fn mask_validity_extension(
193 array: ExtensionArray,
194 mask: &Mask,
195 ctx: &mut ExecutionCtx,
196) -> VortexResult<ExtensionArray> {
197 let storage = array.storage_array().clone().execute::<Canonical>(ctx)?;
199 let masked_storage = mask_validity_canonical(storage, mask, ctx)?;
200 let masked_storage = masked_storage.into_array();
201 Ok(ExtensionArray::new(
202 array
203 .ext_dtype()
204 .with_nullability(masked_storage.dtype().nullability()),
205 masked_storage,
206 ))
207}
208
209fn mask_validity_variant(
210 array: VariantArray,
211 mask: &Mask,
212 ctx: &mut ExecutionCtx,
213) -> VortexResult<VariantArray> {
214 let child = array.child().clone();
215 let len = child.len();
216 let child_validity = child.validity()?;
217
218 match child_validity {
219 Validity::NonNullable | Validity::AllValid => {
220 let new_validity = Validity::from_mask(mask.clone(), Nullability::Nullable);
222 let masked_child = MaskedArray::try_new(child, new_validity)?;
223 Ok(VariantArray::new(masked_child.into_array()))
224 }
225 Validity::AllInvalid => {
226 Ok(array)
228 }
229 Validity::Array(_) => {
230 let combined = combine_validity(&child_validity, mask, len, ctx)?;
233 let new_child = child.with_slot(0, combined.to_array(len))?;
234 Ok(VariantArray::new(new_child))
235 }
236 }
237}