Skip to main content

vortex_array/arrays/masked/
execute.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Execution logic for MaskedArray - applies a validity mask to canonical arrays.
5
6use std::sync::Arc;
7
8use vortex_error::VortexResult;
9
10use crate::Canonical;
11use crate::IntoArray;
12use crate::arrays::BoolArray;
13use crate::arrays::DecimalArray;
14use crate::arrays::ExtensionArray;
15use crate::arrays::FixedSizeListArray;
16use crate::arrays::ListViewArray;
17use crate::arrays::MaskedArray;
18use crate::arrays::PrimitiveArray;
19use crate::arrays::StructArray;
20use crate::arrays::VarBinViewArray;
21use crate::arrays::VariantArray;
22use crate::arrays::bool::BoolArrayExt;
23use crate::arrays::extension::ExtensionArrayExt;
24use crate::arrays::fixed_size_list::FixedSizeListArrayExt;
25use crate::arrays::listview::ListViewArrayExt;
26use crate::arrays::struct_::StructArrayExt;
27use crate::arrays::variant::VariantArrayExt;
28use crate::executor::ExecutionCtx;
29use crate::validity::Validity;
30
31/// TODO: replace usage of compute fn.
32/// Apply a validity mask to a canonical array, ANDing with existing validity.
33///
34/// This is the core operation for MaskedArray execution - it intersects the child's
35/// validity with the provided mask, marking additional positions as invalid.
36pub fn mask_validity_canonical(
37    canonical: Canonical,
38    validity: Validity,
39    ctx: &mut ExecutionCtx,
40) -> VortexResult<Canonical> {
41    Ok(match canonical {
42        n @ Canonical::Null(_) => n,
43        Canonical::Bool(a) => Canonical::Bool(mask_validity_bool(a, validity)?),
44        Canonical::Primitive(a) => Canonical::Primitive(mask_validity_primitive(a, validity)?),
45        Canonical::Decimal(a) => Canonical::Decimal(mask_validity_decimal(a, validity)?),
46        Canonical::VarBinView(a) => Canonical::VarBinView(mask_validity_varbinview(a, validity)?),
47        Canonical::List(a) => Canonical::List(mask_validity_listview(a, validity)?),
48        Canonical::FixedSizeList(a) => {
49            Canonical::FixedSizeList(mask_validity_fixed_size_list(a, validity)?)
50        }
51        Canonical::Struct(a) => Canonical::Struct(mask_validity_struct(a, validity)?),
52        Canonical::Extension(a) => Canonical::Extension(mask_validity_extension(a, validity, ctx)?),
53        Canonical::Variant(a) => Canonical::Variant(mask_validity_variant(a, validity)?),
54    })
55}
56
57fn mask_validity_bool(array: BoolArray, mask: Validity) -> VortexResult<BoolArray> {
58    let new_validity = Validity::and(array.validity()?, mask)?;
59    Ok(BoolArray::new(array.to_bit_buffer(), new_validity))
60}
61
62fn mask_validity_primitive(
63    array: PrimitiveArray,
64    validity: Validity,
65) -> VortexResult<PrimitiveArray> {
66    let ptype = array.ptype();
67    let new_validity = Validity::and(array.validity()?, validity)?;
68    // SAFETY: validity has same length as values
69    Ok(unsafe {
70        PrimitiveArray::new_unchecked_from_handle(
71            array.buffer_handle().clone(),
72            ptype,
73            new_validity,
74        )
75    })
76}
77
78fn mask_validity_decimal(array: DecimalArray, validity: Validity) -> VortexResult<DecimalArray> {
79    let new_validity = Validity::and(array.validity()?, validity)?;
80    // SAFETY: We're only changing validity, not the data structure
81    Ok(unsafe {
82        DecimalArray::new_unchecked_handle(
83            array.buffer_handle().clone(),
84            array.values_type(),
85            array.decimal_dtype(),
86            new_validity,
87        )
88    })
89}
90
91/// Mask validity for VarBinViewArray.
92fn mask_validity_varbinview(
93    array: VarBinViewArray,
94    validity: Validity,
95) -> VortexResult<VarBinViewArray> {
96    let dtype = array.dtype().as_nullable();
97    let new_validity = Validity::and(array.validity()?, validity)?;
98    // SAFETY: We're only changing validity, not the data structure
99    Ok(unsafe {
100        VarBinViewArray::new_handle_unchecked(
101            array.views_handle().clone(),
102            Arc::clone(array.data_buffers()),
103            dtype,
104            new_validity,
105        )
106    })
107}
108
109fn mask_validity_listview(array: ListViewArray, validity: Validity) -> VortexResult<ListViewArray> {
110    let new_validity = Validity::and(array.validity()?, validity)?;
111    // SAFETY: We're only changing validity, not the data structure
112    Ok(unsafe {
113        ListViewArray::new_unchecked(
114            array.elements().clone(),
115            array.offsets().clone(),
116            array.sizes().clone(),
117            new_validity,
118        )
119    })
120}
121
122fn mask_validity_fixed_size_list(
123    array: FixedSizeListArray,
124    validity: Validity,
125) -> VortexResult<FixedSizeListArray> {
126    let len = array.len();
127    let list_size = array.list_size();
128    let new_validity = Validity::and(array.validity()?, validity)?;
129    // SAFETY: We're only changing validity, not the data structure
130    Ok(unsafe {
131        FixedSizeListArray::new_unchecked(array.elements().clone(), list_size, new_validity, len)
132    })
133}
134
135fn mask_validity_struct(array: StructArray, validity: Validity) -> VortexResult<StructArray> {
136    let len = array.len();
137    let new_validity = Validity::and(array.validity()?, validity)?;
138    let fields = array.unmasked_fields();
139    let struct_fields = array.struct_fields();
140    // SAFETY: We're only changing validity, not the data structure
141    Ok(unsafe { StructArray::new_unchecked(fields, struct_fields.clone(), len, new_validity) })
142}
143
144fn mask_validity_extension(
145    array: ExtensionArray,
146    validity: Validity,
147    ctx: &mut ExecutionCtx,
148) -> VortexResult<ExtensionArray> {
149    // For extension arrays, we need to mask the underlying storage
150    let storage = array.storage_array().clone().execute::<Canonical>(ctx)?;
151    let masked_storage = mask_validity_canonical(storage, validity, ctx)?;
152    let masked_storage = masked_storage.into_array();
153    Ok(ExtensionArray::new(
154        array
155            .ext_dtype()
156            .with_nullability(masked_storage.dtype().nullability()),
157        masked_storage,
158    ))
159}
160
161fn mask_validity_variant(array: VariantArray, validity: Validity) -> VortexResult<VariantArray> {
162    let child = array.child().clone();
163    let len = child.len();
164    let child_validity = child.validity()?;
165
166    match child_validity {
167        Validity::NonNullable | Validity::AllValid => {
168            // Child has no nulls — wrap in MaskedArray to apply the mask.
169            let masked_child = MaskedArray::try_new(child, validity)?;
170            Ok(VariantArray::new(masked_child.into_array()))
171        }
172        Validity::AllInvalid => {
173            // Already all-null, ANDing with any mask is still all-null.
174            Ok(array)
175        }
176        Validity::Array(_) => {
177            // Child has an array-backed validity stored as its first child.
178            // Combine with the mask and replace that child via with_children.
179            let combined = Validity::and(child_validity, validity)?;
180            let new_child = child.with_slot(0, combined.to_array(len))?;
181            Ok(VariantArray::new(new_child))
182        }
183    }
184}