vortex_array/arrays/masked/
execute.rs1use std::sync::Arc;
7
8use vortex_error::VortexResult;
9
10use crate::Canonical;
11use crate::IntoArray;
12use crate::arrays::BoolArray;
13use crate::arrays::DecimalArray;
14use crate::arrays::ExtensionArray;
15use crate::arrays::FixedSizeListArray;
16use crate::arrays::ListViewArray;
17use crate::arrays::MaskedArray;
18use crate::arrays::PrimitiveArray;
19use crate::arrays::StructArray;
20use crate::arrays::VarBinViewArray;
21use crate::arrays::VariantArray;
22use crate::arrays::bool::BoolArrayExt;
23use crate::arrays::extension::ExtensionArrayExt;
24use crate::arrays::fixed_size_list::FixedSizeListArrayExt;
25use crate::arrays::listview::ListViewArrayExt;
26use crate::arrays::struct_::StructArrayExt;
27use crate::arrays::variant::VariantArrayExt;
28use crate::builtins::ArrayBuiltins;
29use crate::executor::ExecutionCtx;
30use crate::validity::Validity;
31
32pub fn mask_validity_canonical(
38 canonical: Canonical,
39 validity: Validity,
40 ctx: &mut ExecutionCtx,
41) -> VortexResult<Canonical> {
42 Ok(match canonical {
43 n @ Canonical::Null(_) => n,
44 Canonical::Bool(a) => Canonical::Bool(mask_validity_bool(a, validity)?),
45 Canonical::Primitive(a) => Canonical::Primitive(mask_validity_primitive(a, validity)?),
46 Canonical::Decimal(a) => Canonical::Decimal(mask_validity_decimal(a, validity)?),
47 Canonical::VarBinView(a) => Canonical::VarBinView(mask_validity_varbinview(a, validity)?),
48 Canonical::List(a) => Canonical::List(mask_validity_listview(a, validity)?),
49 Canonical::FixedSizeList(a) => {
50 Canonical::FixedSizeList(mask_validity_fixed_size_list(a, validity)?)
51 }
52 Canonical::Struct(a) => Canonical::Struct(mask_validity_struct(a, validity)?),
53 Canonical::Extension(a) => Canonical::Extension(mask_validity_extension(a, validity, ctx)?),
54 Canonical::Variant(a) => Canonical::Variant(mask_validity_variant(a, validity, ctx)?),
55 })
56}
57
58fn mask_validity_bool(array: BoolArray, mask: Validity) -> VortexResult<BoolArray> {
59 let new_validity = Validity::and(array.validity()?, mask)?;
60 Ok(BoolArray::new(array.to_bit_buffer(), new_validity))
61}
62
63fn mask_validity_primitive(
64 array: PrimitiveArray,
65 validity: Validity,
66) -> VortexResult<PrimitiveArray> {
67 let ptype = array.ptype();
68 let new_validity = Validity::and(array.validity()?, validity)?;
69 Ok(unsafe {
71 PrimitiveArray::new_unchecked_from_handle(
72 array.buffer_handle().clone(),
73 ptype,
74 new_validity,
75 )
76 })
77}
78
79fn mask_validity_decimal(array: DecimalArray, validity: Validity) -> VortexResult<DecimalArray> {
80 let new_validity = Validity::and(array.validity()?, validity)?;
81 Ok(unsafe {
83 DecimalArray::new_unchecked_handle(
84 array.buffer_handle().clone(),
85 array.values_type(),
86 array.decimal_dtype(),
87 new_validity,
88 )
89 })
90}
91
92fn mask_validity_varbinview(
94 array: VarBinViewArray,
95 validity: Validity,
96) -> VortexResult<VarBinViewArray> {
97 let dtype = array.dtype().as_nullable();
98 let new_validity = Validity::and(array.validity()?, validity)?;
99 Ok(unsafe {
101 VarBinViewArray::new_handle_unchecked(
102 array.views_handle().clone(),
103 Arc::clone(array.data_buffers()),
104 dtype,
105 new_validity,
106 )
107 })
108}
109
110fn mask_validity_listview(array: ListViewArray, validity: Validity) -> VortexResult<ListViewArray> {
111 let new_validity = Validity::and(array.validity()?, validity)?;
112 Ok(unsafe {
114 ListViewArray::new_unchecked(
115 array.elements().clone(),
116 array.offsets().clone(),
117 array.sizes().clone(),
118 new_validity,
119 )
120 })
121}
122
123fn mask_validity_fixed_size_list(
124 array: FixedSizeListArray,
125 validity: Validity,
126) -> VortexResult<FixedSizeListArray> {
127 let len = array.len();
128 let list_size = array.list_size();
129 let new_validity = Validity::and(array.validity()?, validity)?;
130 Ok(unsafe {
132 FixedSizeListArray::new_unchecked(array.elements().clone(), list_size, new_validity, len)
133 })
134}
135
136fn mask_validity_struct(array: StructArray, validity: Validity) -> VortexResult<StructArray> {
137 let len = array.len();
138 let new_validity = Validity::and(array.validity()?, validity)?;
139 let fields = array.unmasked_fields();
140 let struct_fields = array.struct_fields();
141 Ok(unsafe { StructArray::new_unchecked(fields, struct_fields.clone(), len, new_validity) })
143}
144
145fn mask_validity_extension(
146 array: ExtensionArray,
147 validity: Validity,
148 ctx: &mut ExecutionCtx,
149) -> VortexResult<ExtensionArray> {
150 let storage = array.storage_array().clone().execute::<Canonical>(ctx)?;
152 let masked_storage = mask_validity_canonical(storage, validity, ctx)?;
153 let masked_storage = masked_storage.into_array();
154 Ok(ExtensionArray::new(
155 array
156 .ext_dtype()
157 .with_nullability(masked_storage.dtype().nullability()),
158 masked_storage,
159 ))
160}
161
162fn mask_validity_variant(
163 array: VariantArray,
164 validity: Validity,
165 ctx: &mut ExecutionCtx,
166) -> VortexResult<VariantArray> {
167 let core_storage = array.core_storage().clone();
168 let len = core_storage.len();
169 let core_validity = core_storage.validity()?;
170 let shredded_validity = validity.clone();
171
172 let masked_core_storage = match core_validity {
173 Validity::NonNullable | Validity::AllValid => {
174 MaskedArray::try_new(core_storage, validity)?.into_array()
176 }
177 Validity::AllInvalid => {
178 core_storage
180 }
181 Validity::Array(_) => {
182 core_storage.mask(validity.to_array(len))?
185 }
186 };
187 let masked_shredded = if let Some(shredded) = array.shredded() {
188 let canonical = shredded.clone().execute::<Canonical>(ctx)?;
189 Some(mask_validity_canonical(canonical, shredded_validity, ctx)?.into_array())
190 } else {
191 None
192 };
193
194 VariantArray::try_new(masked_core_storage, masked_shredded)
195}