vortex_array/arrays/masked/
execute.rs1use std::sync::Arc;
7
8use vortex_error::VortexResult;
9
10use crate::Canonical;
11use crate::IntoArray;
12use crate::arrays::BoolArray;
13use crate::arrays::DecimalArray;
14use crate::arrays::ExtensionArray;
15use crate::arrays::FixedSizeListArray;
16use crate::arrays::ListViewArray;
17use crate::arrays::MaskedArray;
18use crate::arrays::PrimitiveArray;
19use crate::arrays::StructArray;
20use crate::arrays::VarBinViewArray;
21use crate::arrays::VariantArray;
22use crate::arrays::bool::BoolArrayExt;
23use crate::arrays::extension::ExtensionArrayExt;
24use crate::arrays::fixed_size_list::FixedSizeListArrayExt;
25use crate::arrays::listview::ListViewArrayExt;
26use crate::arrays::struct_::StructArrayExt;
27use crate::arrays::variant::VariantArrayExt;
28use crate::executor::ExecutionCtx;
29use crate::validity::Validity;
30
31pub fn mask_validity_canonical(
37 canonical: Canonical,
38 validity: Validity,
39 ctx: &mut ExecutionCtx,
40) -> VortexResult<Canonical> {
41 Ok(match canonical {
42 n @ Canonical::Null(_) => n,
43 Canonical::Bool(a) => Canonical::Bool(mask_validity_bool(a, validity)?),
44 Canonical::Primitive(a) => Canonical::Primitive(mask_validity_primitive(a, validity)?),
45 Canonical::Decimal(a) => Canonical::Decimal(mask_validity_decimal(a, validity)?),
46 Canonical::VarBinView(a) => Canonical::VarBinView(mask_validity_varbinview(a, validity)?),
47 Canonical::List(a) => Canonical::List(mask_validity_listview(a, validity)?),
48 Canonical::FixedSizeList(a) => {
49 Canonical::FixedSizeList(mask_validity_fixed_size_list(a, validity)?)
50 }
51 Canonical::Struct(a) => Canonical::Struct(mask_validity_struct(a, validity)?),
52 Canonical::Extension(a) => Canonical::Extension(mask_validity_extension(a, validity, ctx)?),
53 Canonical::Variant(a) => Canonical::Variant(mask_validity_variant(a, validity)?),
54 })
55}
56
57fn mask_validity_bool(array: BoolArray, mask: Validity) -> VortexResult<BoolArray> {
58 let new_validity = Validity::and(array.validity()?, mask)?;
59 Ok(BoolArray::new(array.to_bit_buffer(), new_validity))
60}
61
62fn mask_validity_primitive(
63 array: PrimitiveArray,
64 validity: Validity,
65) -> VortexResult<PrimitiveArray> {
66 let ptype = array.ptype();
67 let new_validity = Validity::and(array.validity()?, validity)?;
68 Ok(unsafe {
70 PrimitiveArray::new_unchecked_from_handle(
71 array.buffer_handle().clone(),
72 ptype,
73 new_validity,
74 )
75 })
76}
77
78fn mask_validity_decimal(array: DecimalArray, validity: Validity) -> VortexResult<DecimalArray> {
79 let new_validity = Validity::and(array.validity()?, validity)?;
80 Ok(unsafe {
82 DecimalArray::new_unchecked_handle(
83 array.buffer_handle().clone(),
84 array.values_type(),
85 array.decimal_dtype(),
86 new_validity,
87 )
88 })
89}
90
91fn mask_validity_varbinview(
93 array: VarBinViewArray,
94 validity: Validity,
95) -> VortexResult<VarBinViewArray> {
96 let dtype = array.dtype().as_nullable();
97 let new_validity = Validity::and(array.validity()?, validity)?;
98 Ok(unsafe {
100 VarBinViewArray::new_handle_unchecked(
101 array.views_handle().clone(),
102 Arc::clone(array.data_buffers()),
103 dtype,
104 new_validity,
105 )
106 })
107}
108
109fn mask_validity_listview(array: ListViewArray, validity: Validity) -> VortexResult<ListViewArray> {
110 let new_validity = Validity::and(array.validity()?, validity)?;
111 Ok(unsafe {
113 ListViewArray::new_unchecked(
114 array.elements().clone(),
115 array.offsets().clone(),
116 array.sizes().clone(),
117 new_validity,
118 )
119 })
120}
121
122fn mask_validity_fixed_size_list(
123 array: FixedSizeListArray,
124 validity: Validity,
125) -> VortexResult<FixedSizeListArray> {
126 let len = array.len();
127 let list_size = array.list_size();
128 let new_validity = Validity::and(array.validity()?, validity)?;
129 Ok(unsafe {
131 FixedSizeListArray::new_unchecked(array.elements().clone(), list_size, new_validity, len)
132 })
133}
134
135fn mask_validity_struct(array: StructArray, validity: Validity) -> VortexResult<StructArray> {
136 let len = array.len();
137 let new_validity = Validity::and(array.validity()?, validity)?;
138 let fields = array.unmasked_fields();
139 let struct_fields = array.struct_fields();
140 Ok(unsafe { StructArray::new_unchecked(fields, struct_fields.clone(), len, new_validity) })
142}
143
144fn mask_validity_extension(
145 array: ExtensionArray,
146 validity: Validity,
147 ctx: &mut ExecutionCtx,
148) -> VortexResult<ExtensionArray> {
149 let storage = array.storage_array().clone().execute::<Canonical>(ctx)?;
151 let masked_storage = mask_validity_canonical(storage, validity, ctx)?;
152 let masked_storage = masked_storage.into_array();
153 Ok(ExtensionArray::new(
154 array
155 .ext_dtype()
156 .with_nullability(masked_storage.dtype().nullability()),
157 masked_storage,
158 ))
159}
160
161fn mask_validity_variant(array: VariantArray, validity: Validity) -> VortexResult<VariantArray> {
162 let child = array.child().clone();
163 let len = child.len();
164 let child_validity = child.validity()?;
165
166 match child_validity {
167 Validity::NonNullable | Validity::AllValid => {
168 let masked_child = MaskedArray::try_new(child, validity)?;
170 Ok(VariantArray::new(masked_child.into_array()))
171 }
172 Validity::AllInvalid => {
173 Ok(array)
175 }
176 Validity::Array(_) => {
177 let combined = Validity::and(child_validity, validity)?;
180 let new_child = child.with_slot(0, combined.to_array(len))?;
181 Ok(VariantArray::new(new_child))
182 }
183 }
184}