Skip to main content

vortex_array/arrays/masked/
execute.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Execution logic for MaskedArray - applies a validity mask to canonical arrays.
5
6use std::ops::BitAnd;
7
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_mask::Mask;
11
12use crate::Canonical;
13use crate::IntoArray;
14use crate::arrays::BoolArray;
15use crate::arrays::DecimalArray;
16use crate::arrays::ExtensionArray;
17use crate::arrays::FixedSizeListArray;
18use crate::arrays::ListViewArray;
19use crate::arrays::NullArray;
20use crate::arrays::PrimitiveArray;
21use crate::arrays::StructArray;
22use crate::arrays::VarBinViewArray;
23use crate::dtype::Nullability;
24use crate::executor::ExecutionCtx;
25use crate::match_each_decimal_value_type;
26use crate::validity::Validity;
27use crate::vtable::ValidityHelper;
28
29/// TODO: replace usage of compute fn.
30/// Apply a validity mask to a canonical array, ANDing with existing validity.
31///
32/// This is the core operation for MaskedArray execution - it intersects the child's
33/// validity with the provided mask, marking additional positions as invalid.
34pub fn mask_validity_canonical(
35    canonical: Canonical,
36    validity_mask: &Mask,
37    ctx: &mut ExecutionCtx,
38) -> VortexResult<Canonical> {
39    Ok(match canonical {
40        Canonical::Null(a) => Canonical::Null(mask_validity_null(a, validity_mask)),
41        Canonical::Bool(a) => Canonical::Bool(mask_validity_bool(a, validity_mask, ctx)?),
42        Canonical::Primitive(a) => {
43            Canonical::Primitive(mask_validity_primitive(a, validity_mask, ctx)?)
44        }
45        Canonical::Decimal(a) => Canonical::Decimal(mask_validity_decimal(a, validity_mask, ctx)?),
46        Canonical::VarBinView(a) => {
47            Canonical::VarBinView(mask_validity_varbinview(a, validity_mask, ctx)?)
48        }
49        Canonical::List(a) => Canonical::List(mask_validity_listview(a, validity_mask, ctx)?),
50        Canonical::FixedSizeList(a) => {
51            Canonical::FixedSizeList(mask_validity_fixed_size_list(a, validity_mask, ctx)?)
52        }
53        Canonical::Struct(a) => Canonical::Struct(mask_validity_struct(a, validity_mask, ctx)?),
54        Canonical::Extension(a) => {
55            Canonical::Extension(mask_validity_extension(a, validity_mask, ctx)?)
56        }
57        Canonical::Variant(_) => {
58            vortex_bail!("Variant arrays don't masking validity")
59        }
60    })
61}
62
63fn combine_validity(
64    validity: &Validity,
65    mask: &Mask,
66    len: usize,
67    ctx: &mut ExecutionCtx,
68) -> VortexResult<Validity> {
69    let current_mask = validity.execute_mask(len, ctx)?;
70    let combined = current_mask.bitand(mask);
71    Ok(Validity::from_mask(combined, Nullability::Nullable))
72}
73
74fn mask_validity_null(array: NullArray, _mask: &Mask) -> NullArray {
75    array
76}
77
78fn mask_validity_bool(
79    array: BoolArray,
80    mask: &Mask,
81    ctx: &mut ExecutionCtx,
82) -> VortexResult<BoolArray> {
83    let len = array.len();
84    let new_validity = combine_validity(array.validity(), mask, len, ctx)?;
85    Ok(BoolArray::new(array.to_bit_buffer(), new_validity))
86}
87
88fn mask_validity_primitive(
89    array: PrimitiveArray,
90    mask: &Mask,
91    ctx: &mut ExecutionCtx,
92) -> VortexResult<PrimitiveArray> {
93    let len = array.len();
94    let ptype = array.ptype();
95    let new_validity = combine_validity(array.validity(), mask, len, ctx)?;
96    // SAFETY: validity has same length as values
97    Ok(unsafe {
98        PrimitiveArray::new_unchecked_from_handle(
99            array.buffer_handle().clone(),
100            ptype,
101            new_validity,
102        )
103    })
104}
105
106fn mask_validity_decimal(
107    array: DecimalArray,
108    mask: &Mask,
109    ctx: &mut ExecutionCtx,
110) -> VortexResult<DecimalArray> {
111    let len = array.len();
112    let dec_dtype = array.decimal_dtype();
113    let values_type = array.values_type();
114    let new_validity = combine_validity(array.validity(), mask, len, ctx)?;
115    // SAFETY: We're only changing validity, not the data structure
116    Ok(match_each_decimal_value_type!(values_type, |T| {
117        let buffer = array.buffer::<T>();
118        unsafe { DecimalArray::new_unchecked(buffer, dec_dtype, new_validity) }
119    }))
120}
121
122/// Mask validity for VarBinViewArray.
123fn mask_validity_varbinview(
124    array: VarBinViewArray,
125    mask: &Mask,
126    ctx: &mut ExecutionCtx,
127) -> VortexResult<VarBinViewArray> {
128    let len = array.len();
129    let dtype = array.dtype().as_nullable();
130    let new_validity = combine_validity(array.validity(), mask, len, ctx)?;
131    // SAFETY: We're only changing validity, not the data structure
132    Ok(unsafe {
133        VarBinViewArray::new_handle_unchecked(
134            array.views_handle().clone(),
135            array.buffers().clone(),
136            dtype,
137            new_validity,
138        )
139    })
140}
141
142fn mask_validity_listview(
143    array: ListViewArray,
144    mask: &Mask,
145    ctx: &mut ExecutionCtx,
146) -> VortexResult<ListViewArray> {
147    let len = array.len();
148    let new_validity = combine_validity(array.validity(), mask, len, ctx)?;
149    // SAFETY: We're only changing validity, not the data structure
150    Ok(unsafe {
151        ListViewArray::new_unchecked(
152            array.elements().clone(),
153            array.offsets().clone(),
154            array.sizes().clone(),
155            new_validity,
156        )
157    })
158}
159
160fn mask_validity_fixed_size_list(
161    array: FixedSizeListArray,
162    mask: &Mask,
163    ctx: &mut ExecutionCtx,
164) -> VortexResult<FixedSizeListArray> {
165    let len = array.len();
166    let list_size = array.list_size();
167    let new_validity = combine_validity(array.validity(), mask, len, ctx)?;
168    // SAFETY: We're only changing validity, not the data structure
169    Ok(unsafe {
170        FixedSizeListArray::new_unchecked(array.elements().clone(), list_size, new_validity, len)
171    })
172}
173
174fn mask_validity_struct(
175    array: StructArray,
176    mask: &Mask,
177    ctx: &mut ExecutionCtx,
178) -> VortexResult<StructArray> {
179    let len = array.len();
180    let new_validity = combine_validity(array.validity(), mask, len, ctx)?;
181    let fields = array.unmasked_fields().clone();
182    let struct_fields = array.struct_fields().clone();
183    // SAFETY: We're only changing validity, not the data structure
184    Ok(unsafe { StructArray::new_unchecked(fields, struct_fields, len, new_validity) })
185}
186
187fn mask_validity_extension(
188    array: ExtensionArray,
189    mask: &Mask,
190    ctx: &mut ExecutionCtx,
191) -> VortexResult<ExtensionArray> {
192    // For extension arrays, we need to mask the underlying storage
193    let storage = array.storage_array().clone().execute::<Canonical>(ctx)?;
194    let masked_storage = mask_validity_canonical(storage, mask, ctx)?;
195    let masked_storage = masked_storage.into_array();
196    Ok(ExtensionArray::new(
197        array
198            .ext_dtype()
199            .with_nullability(masked_storage.dtype().nullability()),
200        masked_storage,
201    ))
202}