Skip to main content

vortex_array/arrays/masked/
execute.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Execution logic for MaskedArray - applies a validity mask to canonical arrays.
5
use std::ops::BitAnd;

use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_mask::Mask;

use crate::Canonical;
use crate::IntoArray;
use crate::arrays::BoolArray;
use crate::arrays::DecimalArray;
use crate::arrays::ExtensionArray;
use crate::arrays::FixedSizeListArray;
use crate::arrays::ListViewArray;
use crate::arrays::NullArray;
use crate::arrays::PrimitiveArray;
use crate::arrays::StructArray;
use crate::arrays::VarBinViewArray;
use crate::dtype::Nullability;
use crate::executor::ExecutionCtx;
use crate::match_each_decimal_value_type;
use crate::validity::Validity;
use crate::vtable::ValidityHelper;
27
28/// TODO: replace usage of compute fn.
29/// Apply a validity mask to a canonical array, ANDing with existing validity.
30///
31/// This is the core operation for MaskedArray execution - it intersects the child's
32/// validity with the provided mask, marking additional positions as invalid.
33pub fn mask_validity_canonical(
34    canonical: Canonical,
35    validity_mask: &Mask,
36    ctx: &mut ExecutionCtx,
37) -> VortexResult<Canonical> {
38    Ok(match canonical {
39        Canonical::Null(a) => Canonical::Null(mask_validity_null(a, validity_mask)),
40        Canonical::Bool(a) => Canonical::Bool(mask_validity_bool(a, validity_mask, ctx)?),
41        Canonical::Primitive(a) => {
42            Canonical::Primitive(mask_validity_primitive(a, validity_mask, ctx)?)
43        }
44        Canonical::Decimal(a) => Canonical::Decimal(mask_validity_decimal(a, validity_mask, ctx)?),
45        Canonical::VarBinView(a) => {
46            Canonical::VarBinView(mask_validity_varbinview(a, validity_mask, ctx)?)
47        }
48        Canonical::List(a) => Canonical::List(mask_validity_listview(a, validity_mask, ctx)?),
49        Canonical::FixedSizeList(a) => {
50            Canonical::FixedSizeList(mask_validity_fixed_size_list(a, validity_mask, ctx)?)
51        }
52        Canonical::Struct(a) => Canonical::Struct(mask_validity_struct(a, validity_mask, ctx)?),
53        Canonical::Extension(a) => {
54            Canonical::Extension(mask_validity_extension(a, validity_mask, ctx)?)
55        }
56    })
57}
58
59fn combine_validity(
60    validity: &Validity,
61    mask: &Mask,
62    len: usize,
63    ctx: &mut ExecutionCtx,
64) -> VortexResult<Validity> {
65    let current_mask = validity.execute_mask(len, ctx)?;
66    let combined = current_mask.bitand(mask);
67    Ok(Validity::from_mask(combined, Nullability::Nullable))
68}
69
70fn mask_validity_null(array: NullArray, _mask: &Mask) -> NullArray {
71    array
72}
73
74fn mask_validity_bool(
75    array: BoolArray,
76    mask: &Mask,
77    ctx: &mut ExecutionCtx,
78) -> VortexResult<BoolArray> {
79    let len = array.len();
80    let new_validity = combine_validity(array.validity(), mask, len, ctx)?;
81    Ok(BoolArray::new(array.to_bit_buffer(), new_validity))
82}
83
84fn mask_validity_primitive(
85    array: PrimitiveArray,
86    mask: &Mask,
87    ctx: &mut ExecutionCtx,
88) -> VortexResult<PrimitiveArray> {
89    let len = array.len();
90    let ptype = array.ptype();
91    let new_validity = combine_validity(array.validity(), mask, len, ctx)?;
92    // SAFETY: validity has same length as values
93    Ok(unsafe {
94        PrimitiveArray::new_unchecked_from_handle(
95            array.buffer_handle().clone(),
96            ptype,
97            new_validity,
98        )
99    })
100}
101
102fn mask_validity_decimal(
103    array: DecimalArray,
104    mask: &Mask,
105    ctx: &mut ExecutionCtx,
106) -> VortexResult<DecimalArray> {
107    let len = array.len();
108    let dec_dtype = array.decimal_dtype();
109    let values_type = array.values_type();
110    let new_validity = combine_validity(array.validity(), mask, len, ctx)?;
111    // SAFETY: We're only changing validity, not the data structure
112    Ok(match_each_decimal_value_type!(values_type, |T| {
113        let buffer = array.buffer::<T>();
114        unsafe { DecimalArray::new_unchecked(buffer, dec_dtype, new_validity) }
115    }))
116}
117
118/// Mask validity for VarBinViewArray.
119fn mask_validity_varbinview(
120    array: VarBinViewArray,
121    mask: &Mask,
122    ctx: &mut ExecutionCtx,
123) -> VortexResult<VarBinViewArray> {
124    let len = array.len();
125    let dtype = array.dtype().as_nullable();
126    let new_validity = combine_validity(array.validity(), mask, len, ctx)?;
127    // SAFETY: We're only changing validity, not the data structure
128    Ok(unsafe {
129        VarBinViewArray::new_handle_unchecked(
130            array.views_handle().clone(),
131            array.buffers().clone(),
132            dtype,
133            new_validity,
134        )
135    })
136}
137
138fn mask_validity_listview(
139    array: ListViewArray,
140    mask: &Mask,
141    ctx: &mut ExecutionCtx,
142) -> VortexResult<ListViewArray> {
143    let len = array.len();
144    let new_validity = combine_validity(array.validity(), mask, len, ctx)?;
145    // SAFETY: We're only changing validity, not the data structure
146    Ok(unsafe {
147        ListViewArray::new_unchecked(
148            array.elements().clone(),
149            array.offsets().clone(),
150            array.sizes().clone(),
151            new_validity,
152        )
153    })
154}
155
156fn mask_validity_fixed_size_list(
157    array: FixedSizeListArray,
158    mask: &Mask,
159    ctx: &mut ExecutionCtx,
160) -> VortexResult<FixedSizeListArray> {
161    let len = array.len();
162    let list_size = array.list_size();
163    let new_validity = combine_validity(array.validity(), mask, len, ctx)?;
164    // SAFETY: We're only changing validity, not the data structure
165    Ok(unsafe {
166        FixedSizeListArray::new_unchecked(array.elements().clone(), list_size, new_validity, len)
167    })
168}
169
170fn mask_validity_struct(
171    array: StructArray,
172    mask: &Mask,
173    ctx: &mut ExecutionCtx,
174) -> VortexResult<StructArray> {
175    let len = array.len();
176    let new_validity = combine_validity(array.validity(), mask, len, ctx)?;
177    let fields = array.unmasked_fields().clone();
178    let struct_fields = array.struct_fields().clone();
179    // SAFETY: We're only changing validity, not the data structure
180    Ok(unsafe { StructArray::new_unchecked(fields, struct_fields, len, new_validity) })
181}
182
183fn mask_validity_extension(
184    array: ExtensionArray,
185    mask: &Mask,
186    ctx: &mut ExecutionCtx,
187) -> VortexResult<ExtensionArray> {
188    // For extension arrays, we need to mask the underlying storage
189    let storage = array.storage_array().clone().execute::<Canonical>(ctx)?;
190    let masked_storage = mask_validity_canonical(storage, mask, ctx)?;
191    let masked_storage = masked_storage.into_array();
192    Ok(ExtensionArray::new(
193        array
194            .ext_dtype()
195            .with_nullability(masked_storage.dtype().nullability()),
196        masked_storage,
197    ))
198}