Skip to main content

vortex_array/arrays/masked/
execute.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Execution logic for MaskedArray - applies a validity mask to canonical arrays.
5
6use std::sync::Arc;
7
8use vortex_error::VortexResult;
9
10use crate::Canonical;
11use crate::IntoArray;
12use crate::arrays::BoolArray;
13use crate::arrays::DecimalArray;
14use crate::arrays::ExtensionArray;
15use crate::arrays::FixedSizeListArray;
16use crate::arrays::ListViewArray;
17use crate::arrays::MaskedArray;
18use crate::arrays::PrimitiveArray;
19use crate::arrays::StructArray;
20use crate::arrays::VarBinViewArray;
21use crate::arrays::VariantArray;
22use crate::arrays::bool::BoolArrayExt;
23use crate::arrays::extension::ExtensionArrayExt;
24use crate::arrays::fixed_size_list::FixedSizeListArrayExt;
25use crate::arrays::listview::ListViewArrayExt;
26use crate::arrays::struct_::StructArrayExt;
27use crate::arrays::variant::VariantArrayExt;
28use crate::builtins::ArrayBuiltins;
29use crate::executor::ExecutionCtx;
30use crate::validity::Validity;
31
32/// TODO: replace usage of compute fn.
33/// Apply a validity mask to a canonical array, ANDing with existing validity.
34///
35/// This is the core operation for MaskedArray execution - it intersects the child's
36/// validity with the provided mask, marking additional positions as invalid.
37pub fn mask_validity_canonical(
38    canonical: Canonical,
39    validity: Validity,
40    ctx: &mut ExecutionCtx,
41) -> VortexResult<Canonical> {
42    Ok(match canonical {
43        n @ Canonical::Null(_) => n,
44        Canonical::Bool(a) => Canonical::Bool(mask_validity_bool(a, validity)?),
45        Canonical::Primitive(a) => Canonical::Primitive(mask_validity_primitive(a, validity)?),
46        Canonical::Decimal(a) => Canonical::Decimal(mask_validity_decimal(a, validity)?),
47        Canonical::VarBinView(a) => Canonical::VarBinView(mask_validity_varbinview(a, validity)?),
48        Canonical::List(a) => Canonical::List(mask_validity_listview(a, validity)?),
49        Canonical::FixedSizeList(a) => {
50            Canonical::FixedSizeList(mask_validity_fixed_size_list(a, validity)?)
51        }
52        Canonical::Struct(a) => Canonical::Struct(mask_validity_struct(a, validity)?),
53        Canonical::Extension(a) => Canonical::Extension(mask_validity_extension(a, validity, ctx)?),
54        Canonical::Variant(a) => Canonical::Variant(mask_validity_variant(a, validity, ctx)?),
55    })
56}
57
58fn mask_validity_bool(array: BoolArray, mask: Validity) -> VortexResult<BoolArray> {
59    let new_validity = Validity::and(array.validity()?, mask)?;
60    Ok(BoolArray::new(array.to_bit_buffer(), new_validity))
61}
62
63fn mask_validity_primitive(
64    array: PrimitiveArray,
65    validity: Validity,
66) -> VortexResult<PrimitiveArray> {
67    let ptype = array.ptype();
68    let new_validity = Validity::and(array.validity()?, validity)?;
69    // SAFETY: validity has same length as values
70    Ok(unsafe {
71        PrimitiveArray::new_unchecked_from_handle(
72            array.buffer_handle().clone(),
73            ptype,
74            new_validity,
75        )
76    })
77}
78
79fn mask_validity_decimal(array: DecimalArray, validity: Validity) -> VortexResult<DecimalArray> {
80    let new_validity = Validity::and(array.validity()?, validity)?;
81    // SAFETY: We're only changing validity, not the data structure
82    Ok(unsafe {
83        DecimalArray::new_unchecked_handle(
84            array.buffer_handle().clone(),
85            array.values_type(),
86            array.decimal_dtype(),
87            new_validity,
88        )
89    })
90}
91
92/// Mask validity for VarBinViewArray.
93fn mask_validity_varbinview(
94    array: VarBinViewArray,
95    validity: Validity,
96) -> VortexResult<VarBinViewArray> {
97    let dtype = array.dtype().as_nullable();
98    let new_validity = Validity::and(array.validity()?, validity)?;
99    // SAFETY: We're only changing validity, not the data structure
100    Ok(unsafe {
101        VarBinViewArray::new_handle_unchecked(
102            array.views_handle().clone(),
103            Arc::clone(array.data_buffers()),
104            dtype,
105            new_validity,
106        )
107    })
108}
109
110fn mask_validity_listview(array: ListViewArray, validity: Validity) -> VortexResult<ListViewArray> {
111    let new_validity = Validity::and(array.validity()?, validity)?;
112    // SAFETY: We're only changing validity, not the data structure
113    Ok(unsafe {
114        ListViewArray::new_unchecked(
115            array.elements().clone(),
116            array.offsets().clone(),
117            array.sizes().clone(),
118            new_validity,
119        )
120    })
121}
122
123fn mask_validity_fixed_size_list(
124    array: FixedSizeListArray,
125    validity: Validity,
126) -> VortexResult<FixedSizeListArray> {
127    let len = array.len();
128    let list_size = array.list_size();
129    let new_validity = Validity::and(array.validity()?, validity)?;
130    // SAFETY: We're only changing validity, not the data structure
131    Ok(unsafe {
132        FixedSizeListArray::new_unchecked(array.elements().clone(), list_size, new_validity, len)
133    })
134}
135
136fn mask_validity_struct(array: StructArray, validity: Validity) -> VortexResult<StructArray> {
137    let len = array.len();
138    let new_validity = Validity::and(array.validity()?, validity)?;
139    let fields = array.unmasked_fields();
140    let struct_fields = array.struct_fields();
141    // SAFETY: We're only changing validity, not the data structure
142    Ok(unsafe { StructArray::new_unchecked(fields, struct_fields.clone(), len, new_validity) })
143}
144
145fn mask_validity_extension(
146    array: ExtensionArray,
147    validity: Validity,
148    ctx: &mut ExecutionCtx,
149) -> VortexResult<ExtensionArray> {
150    // For extension arrays, we need to mask the underlying storage
151    let storage = array.storage_array().clone().execute::<Canonical>(ctx)?;
152    let masked_storage = mask_validity_canonical(storage, validity, ctx)?;
153    let masked_storage = masked_storage.into_array();
154    Ok(ExtensionArray::new(
155        array
156            .ext_dtype()
157            .with_nullability(masked_storage.dtype().nullability()),
158        masked_storage,
159    ))
160}
161
162fn mask_validity_variant(
163    array: VariantArray,
164    validity: Validity,
165    ctx: &mut ExecutionCtx,
166) -> VortexResult<VariantArray> {
167    let core_storage = array.core_storage().clone();
168    let len = core_storage.len();
169    let core_validity = core_storage.validity()?;
170    let shredded_validity = validity.clone();
171
172    let masked_core_storage = match core_validity {
173        Validity::NonNullable | Validity::AllValid => {
174            // Core storage has no nulls, so wrap it in MaskedArray to apply the mask.
175            MaskedArray::try_new(core_storage, validity)?.into_array()
176        }
177        Validity::AllInvalid => {
178            // Already all-null, ANDing with any mask is still all-null.
179            core_storage
180        }
181        Validity::Array(_) => {
182            // Core storage already has nulls, but its physical validity layout depends on the
183            // actual encoding. Use the mask operation instead of rewriting a presumed slot.
184            core_storage.mask(validity.to_array(len))?
185        }
186    };
187    let masked_shredded = if let Some(shredded) = array.shredded() {
188        let canonical = shredded.clone().execute::<Canonical>(ctx)?;
189        Some(mask_validity_canonical(canonical, shredded_validity, ctx)?.into_array())
190    } else {
191        None
192    };
193
194    VariantArray::try_new(masked_core_storage, masked_shredded)
195}