Skip to main content

vortex_array/arrays/masked/
execute.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Execution logic for MaskedArray - applies a validity mask to canonical arrays.
5
6use std::ops::BitAnd;
7use std::sync::Arc;
8
9use vortex_error::VortexResult;
10use vortex_mask::Mask;
11
12use crate::Canonical;
13use crate::IntoArray;
14use crate::arrays::BoolArray;
15use crate::arrays::DecimalArray;
16use crate::arrays::ExtensionArray;
17use crate::arrays::FixedSizeListArray;
18use crate::arrays::ListViewArray;
19use crate::arrays::MaskedArray;
20use crate::arrays::NullArray;
21use crate::arrays::PrimitiveArray;
22use crate::arrays::StructArray;
23use crate::arrays::VarBinViewArray;
24use crate::arrays::VariantArray;
25use crate::arrays::bool::BoolArrayExt;
26use crate::arrays::extension::ExtensionArrayExt;
27use crate::arrays::fixed_size_list::FixedSizeListArrayExt;
28use crate::arrays::listview::ListViewArrayExt;
29use crate::arrays::struct_::StructArrayExt;
30use crate::arrays::variant::VariantArrayExt;
31use crate::dtype::Nullability;
32use crate::executor::ExecutionCtx;
33use crate::match_each_decimal_value_type;
34use crate::validity::Validity;
35
36/// TODO: replace usage of compute fn.
37/// Apply a validity mask to a canonical array, ANDing with existing validity.
38///
39/// This is the core operation for MaskedArray execution - it intersects the child's
40/// validity with the provided mask, marking additional positions as invalid.
41pub fn mask_validity_canonical(
42    canonical: Canonical,
43    validity_mask: &Mask,
44    ctx: &mut ExecutionCtx,
45) -> VortexResult<Canonical> {
46    Ok(match canonical {
47        Canonical::Null(a) => Canonical::Null(mask_validity_null(a, validity_mask)),
48        Canonical::Bool(a) => Canonical::Bool(mask_validity_bool(a, validity_mask, ctx)?),
49        Canonical::Primitive(a) => {
50            Canonical::Primitive(mask_validity_primitive(a, validity_mask, ctx)?)
51        }
52        Canonical::Decimal(a) => Canonical::Decimal(mask_validity_decimal(a, validity_mask, ctx)?),
53        Canonical::VarBinView(a) => {
54            Canonical::VarBinView(mask_validity_varbinview(a, validity_mask, ctx)?)
55        }
56        Canonical::List(a) => Canonical::List(mask_validity_listview(a, validity_mask, ctx)?),
57        Canonical::FixedSizeList(a) => {
58            Canonical::FixedSizeList(mask_validity_fixed_size_list(a, validity_mask, ctx)?)
59        }
60        Canonical::Struct(a) => Canonical::Struct(mask_validity_struct(a, validity_mask, ctx)?),
61        Canonical::Extension(a) => {
62            Canonical::Extension(mask_validity_extension(a, validity_mask, ctx)?)
63        }
64        Canonical::Variant(a) => Canonical::Variant(mask_validity_variant(a, validity_mask, ctx)?),
65    })
66}
67
68fn combine_validity(
69    validity: &Validity,
70    mask: &Mask,
71    len: usize,
72    ctx: &mut ExecutionCtx,
73) -> VortexResult<Validity> {
74    let current_mask = validity.execute_mask(len, ctx)?;
75    let combined = current_mask.bitand(mask);
76    Ok(Validity::from_mask(combined, Nullability::Nullable))
77}
78
79fn mask_validity_null(array: NullArray, _mask: &Mask) -> NullArray {
80    array
81}
82
83fn mask_validity_bool(
84    array: BoolArray,
85    mask: &Mask,
86    ctx: &mut ExecutionCtx,
87) -> VortexResult<BoolArray> {
88    let len = array.len();
89    let new_validity = combine_validity(&array.validity()?, mask, len, ctx)?;
90    Ok(BoolArray::new(array.to_bit_buffer(), new_validity))
91}
92
93fn mask_validity_primitive(
94    array: PrimitiveArray,
95    mask: &Mask,
96    ctx: &mut ExecutionCtx,
97) -> VortexResult<PrimitiveArray> {
98    let len = array.len();
99    let ptype = array.ptype();
100    let new_validity = combine_validity(&array.validity()?, mask, len, ctx)?;
101    // SAFETY: validity has same length as values
102    Ok(unsafe {
103        PrimitiveArray::new_unchecked_from_handle(
104            array.buffer_handle().clone(),
105            ptype,
106            new_validity,
107        )
108    })
109}
110
111fn mask_validity_decimal(
112    array: DecimalArray,
113    mask: &Mask,
114    ctx: &mut ExecutionCtx,
115) -> VortexResult<DecimalArray> {
116    let len = array.len();
117    let dec_dtype = array.decimal_dtype();
118    let values_type = array.values_type();
119    let new_validity = combine_validity(&array.validity()?, mask, len, ctx)?;
120    // SAFETY: We're only changing validity, not the data structure
121    Ok(match_each_decimal_value_type!(values_type, |T| {
122        let buffer = array.buffer::<T>();
123        unsafe { DecimalArray::new_unchecked(buffer, dec_dtype, new_validity) }
124    }))
125}
126
127/// Mask validity for VarBinViewArray.
128fn mask_validity_varbinview(
129    array: VarBinViewArray,
130    mask: &Mask,
131    ctx: &mut ExecutionCtx,
132) -> VortexResult<VarBinViewArray> {
133    let len = array.len();
134    let dtype = array.dtype().as_nullable();
135    let new_validity = combine_validity(&array.validity()?, mask, len, ctx)?;
136    // SAFETY: We're only changing validity, not the data structure
137    Ok(unsafe {
138        VarBinViewArray::new_handle_unchecked(
139            array.views_handle().clone(),
140            Arc::clone(array.data_buffers()),
141            dtype,
142            new_validity,
143        )
144    })
145}
146
147fn mask_validity_listview(
148    array: ListViewArray,
149    mask: &Mask,
150    ctx: &mut ExecutionCtx,
151) -> VortexResult<ListViewArray> {
152    let len = array.len();
153    let new_validity = combine_validity(&array.validity()?, mask, len, ctx)?;
154    // SAFETY: We're only changing validity, not the data structure
155    Ok(unsafe {
156        ListViewArray::new_unchecked(
157            array.elements().clone(),
158            array.offsets().clone(),
159            array.sizes().clone(),
160            new_validity,
161        )
162    })
163}
164
165fn mask_validity_fixed_size_list(
166    array: FixedSizeListArray,
167    mask: &Mask,
168    ctx: &mut ExecutionCtx,
169) -> VortexResult<FixedSizeListArray> {
170    let len = array.len();
171    let list_size = array.list_size();
172    let new_validity = combine_validity(&array.validity()?, mask, len, ctx)?;
173    // SAFETY: We're only changing validity, not the data structure
174    Ok(unsafe {
175        FixedSizeListArray::new_unchecked(array.elements().clone(), list_size, new_validity, len)
176    })
177}
178
179fn mask_validity_struct(
180    array: StructArray,
181    mask: &Mask,
182    ctx: &mut ExecutionCtx,
183) -> VortexResult<StructArray> {
184    let len = array.len();
185    let new_validity = combine_validity(&array.validity()?, mask, len, ctx)?;
186    let fields = array.unmasked_fields();
187    let struct_fields = array.struct_fields();
188    // SAFETY: We're only changing validity, not the data structure
189    Ok(unsafe { StructArray::new_unchecked(fields, struct_fields.clone(), len, new_validity) })
190}
191
192fn mask_validity_extension(
193    array: ExtensionArray,
194    mask: &Mask,
195    ctx: &mut ExecutionCtx,
196) -> VortexResult<ExtensionArray> {
197    // For extension arrays, we need to mask the underlying storage
198    let storage = array.storage_array().clone().execute::<Canonical>(ctx)?;
199    let masked_storage = mask_validity_canonical(storage, mask, ctx)?;
200    let masked_storage = masked_storage.into_array();
201    Ok(ExtensionArray::new(
202        array
203            .ext_dtype()
204            .with_nullability(masked_storage.dtype().nullability()),
205        masked_storage,
206    ))
207}
208
209fn mask_validity_variant(
210    array: VariantArray,
211    mask: &Mask,
212    ctx: &mut ExecutionCtx,
213) -> VortexResult<VariantArray> {
214    let child = array.child().clone();
215    let len = child.len();
216    let child_validity = child.validity()?;
217
218    match child_validity {
219        Validity::NonNullable | Validity::AllValid => {
220            // Child has no nulls — wrap in MaskedArray to apply the mask.
221            let new_validity = Validity::from_mask(mask.clone(), Nullability::Nullable);
222            let masked_child = MaskedArray::try_new(child, new_validity)?;
223            Ok(VariantArray::new(masked_child.into_array()))
224        }
225        Validity::AllInvalid => {
226            // Already all-null, ANDing with any mask is still all-null.
227            Ok(array)
228        }
229        Validity::Array(_) => {
230            // Child has an array-backed validity stored as its first child.
231            // Combine with the mask and replace that child via with_children.
232            let combined = combine_validity(&child_validity, mask, len, ctx)?;
233            let new_child = child.with_slot(0, combined.to_array(len))?;
234            Ok(VariantArray::new(new_child))
235        }
236    }
237}