Skip to main content

vortex_array/arrays/masked/
execute.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Execution logic for MaskedArray - applies a validity mask to canonical arrays.
5
6use std::ops::BitAnd;
7
8use vortex_dtype::Nullability;
9use vortex_dtype::match_each_decimal_value_type;
10use vortex_error::VortexResult;
11use vortex_mask::Mask;
12
13use crate::Canonical;
14use crate::IntoArray;
15use crate::arrays::BoolArray;
16use crate::arrays::DecimalArray;
17use crate::arrays::ExtensionArray;
18use crate::arrays::FixedSizeListArray;
19use crate::arrays::ListViewArray;
20use crate::arrays::NullArray;
21use crate::arrays::PrimitiveArray;
22use crate::arrays::StructArray;
23use crate::arrays::VarBinViewArray;
24use crate::executor::ExecutionCtx;
25use crate::validity::Validity;
26use crate::vtable::ValidityHelper;
27
28/// TODO: replace usage of compute fn.
29/// Apply a validity mask to a canonical array, ANDing with existing validity.
30///
31/// This is the core operation for MaskedArray execution - it intersects the child's
32/// validity with the provided mask, marking additional positions as invalid.
33pub fn mask_validity_canonical(
34    canonical: Canonical,
35    validity_mask: &Mask,
36    ctx: &mut ExecutionCtx,
37) -> VortexResult<Canonical> {
38    Ok(match canonical {
39        Canonical::Null(a) => Canonical::Null(mask_validity_null(a, validity_mask)),
40        Canonical::Bool(a) => Canonical::Bool(mask_validity_bool(a, validity_mask)),
41        Canonical::Primitive(a) => Canonical::Primitive(mask_validity_primitive(a, validity_mask)),
42        Canonical::Decimal(a) => Canonical::Decimal(mask_validity_decimal(a, validity_mask)),
43        Canonical::VarBinView(a) => {
44            Canonical::VarBinView(mask_validity_varbinview(a, validity_mask))
45        }
46        Canonical::List(a) => Canonical::List(mask_validity_listview(a, validity_mask)),
47        Canonical::FixedSizeList(a) => {
48            Canonical::FixedSizeList(mask_validity_fixed_size_list(a, validity_mask))
49        }
50        Canonical::Struct(a) => Canonical::Struct(mask_validity_struct(a, validity_mask)),
51        Canonical::Extension(a) => {
52            Canonical::Extension(mask_validity_extension(a, validity_mask, ctx)?)
53        }
54    })
55}
56
57fn combine_validity(validity: &Validity, mask: &Mask, len: usize) -> Validity {
58    let current_mask = validity.to_mask(len);
59    let combined = current_mask.bitand(mask);
60    Validity::from_mask(combined, Nullability::Nullable)
61}
62
63fn mask_validity_null(array: NullArray, _mask: &Mask) -> NullArray {
64    array
65}
66
67fn mask_validity_bool(array: BoolArray, mask: &Mask) -> BoolArray {
68    let len = array.len();
69    let new_validity = combine_validity(array.validity(), mask, len);
70    BoolArray::new(array.to_bit_buffer(), new_validity)
71}
72
73fn mask_validity_primitive(array: PrimitiveArray, mask: &Mask) -> PrimitiveArray {
74    let len = array.len();
75    let ptype = array.ptype();
76    let new_validity = combine_validity(array.validity(), mask, len);
77    // SAFETY: validity has same length as values
78    unsafe {
79        PrimitiveArray::new_unchecked_from_handle(
80            array.buffer_handle().clone(),
81            ptype,
82            new_validity,
83        )
84    }
85}
86
87fn mask_validity_decimal(array: DecimalArray, mask: &Mask) -> DecimalArray {
88    let len = array.len();
89    let dec_dtype = array.decimal_dtype();
90    let values_type = array.values_type();
91    let new_validity = combine_validity(array.validity(), mask, len);
92    // SAFETY: We're only changing validity, not the data structure
93    match_each_decimal_value_type!(values_type, |T| {
94        let buffer = array.buffer::<T>();
95        unsafe { DecimalArray::new_unchecked(buffer, dec_dtype, new_validity) }
96    })
97}
98
99/// Mask validity for VarBinViewArray.
100fn mask_validity_varbinview(array: VarBinViewArray, mask: &Mask) -> VarBinViewArray {
101    let len = array.len();
102    let dtype = array.dtype().as_nullable();
103    let new_validity = combine_validity(array.validity(), mask, len);
104    // SAFETY: We're only changing validity, not the data structure
105    unsafe {
106        VarBinViewArray::new_handle_unchecked(
107            array.views_handle().clone(),
108            array.buffers().clone(),
109            dtype,
110            new_validity,
111        )
112    }
113}
114
115fn mask_validity_listview(array: ListViewArray, mask: &Mask) -> ListViewArray {
116    let len = array.len();
117    let new_validity = combine_validity(array.validity(), mask, len);
118    // SAFETY: We're only changing validity, not the data structure
119    unsafe {
120        ListViewArray::new_unchecked(
121            array.elements().clone(),
122            array.offsets().clone(),
123            array.sizes().clone(),
124            new_validity,
125        )
126    }
127}
128
129fn mask_validity_fixed_size_list(array: FixedSizeListArray, mask: &Mask) -> FixedSizeListArray {
130    let len = array.len();
131    let list_size = array.list_size();
132    let new_validity = combine_validity(array.validity(), mask, len);
133    // SAFETY: We're only changing validity, not the data structure
134    unsafe {
135        FixedSizeListArray::new_unchecked(array.elements().clone(), list_size, new_validity, len)
136    }
137}
138
139fn mask_validity_struct(array: StructArray, mask: &Mask) -> StructArray {
140    let len = array.len();
141    let new_validity = combine_validity(array.validity(), mask, len);
142    let fields = array.unmasked_fields().clone();
143    let struct_fields = array.struct_fields().clone();
144    // SAFETY: We're only changing validity, not the data structure
145    unsafe { StructArray::new_unchecked(fields, struct_fields, len, new_validity) }
146}
147
148fn mask_validity_extension(
149    array: ExtensionArray,
150    mask: &Mask,
151    ctx: &mut ExecutionCtx,
152) -> VortexResult<ExtensionArray> {
153    // For extension arrays, we need to mask the underlying storage
154    let storage = array.storage().clone().execute::<Canonical>(ctx)?;
155    let masked_storage = mask_validity_canonical(storage, mask, ctx)?;
156    let masked_storage = masked_storage.into_array();
157    Ok(ExtensionArray::new(
158        array
159            .ext_dtype()
160            .with_nullability(masked_storage.dtype().nullability()),
161        masked_storage,
162    ))
163}