vortex_array/arrays/masked/vtable/
mod.rs1mod canonical;
4mod operations;
5mod validity;
6
7use std::hash::Hasher;
8
9use smallvec::smallvec;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_ensure;
14use vortex_error::vortex_panic;
15use vortex_session::VortexSession;
16use vortex_session::registry::CachedId;
17
18use crate::AnyCanonical;
19use crate::ArrayEq;
20use crate::ArrayHash;
21use crate::ArrayParts;
22use crate::ArrayRef;
23use crate::Canonical;
24use crate::EqMode;
25use crate::IntoArray;
26use crate::LEGACY_SESSION;
27use crate::VortexSessionExecute;
28use crate::array::Array;
29use crate::array::ArrayId;
30use crate::array::ArrayView;
31use crate::array::VTable;
32use crate::array::validity_to_child;
33use crate::array::with_empty_buffers;
34use crate::arrays::ConstantArray;
35use crate::arrays::masked::MaskedArrayExt;
36use crate::arrays::masked::MaskedArraySlotsExt;
37use crate::arrays::masked::MaskedData;
38use crate::arrays::masked::array::MaskedSlots;
39use crate::arrays::masked::compute::rules::PARENT_RULES;
40use crate::arrays::masked::mask_validity_canonical;
41use crate::buffer::BufferHandle;
42use crate::dtype::DType;
43use crate::executor::ExecutionCtx;
44use crate::executor::ExecutionResult;
45use crate::require_child;
46use crate::scalar::Scalar;
47use crate::serde::ArrayChildren;
48use crate::validity::Validity;
49pub type MaskedArray = Array<Masked>;
51
52#[derive(Clone, Debug)]
53pub struct Masked;
54
55impl ArrayHash for MaskedData {
56 fn array_hash<H: Hasher>(&self, _state: &mut H, _accuracy: EqMode) {}
57}
58
59impl ArrayEq for MaskedData {
60 fn array_eq(&self, _other: &Self, _accuracy: EqMode) -> bool {
61 true
62 }
63}
64
65impl VTable for Masked {
66 type TypedArrayData = MaskedData;
67
68 type OperationsVTable = Self;
69 type ValidityVTable = Self;
70
71 fn id(&self) -> ArrayId {
72 static ID: CachedId = CachedId::new("vortex.masked");
73 *ID
74 }
75
76 fn validate(
77 &self,
78 _data: &MaskedData,
79 dtype: &DType,
80 len: usize,
81 slots: &[Option<ArrayRef>],
82 ) -> VortexResult<()> {
83 vortex_ensure!(
84 slots[MaskedSlots::CHILD].is_some(),
85 "MaskedArray child slot must be present"
86 );
87 let child = slots[MaskedSlots::CHILD]
88 .as_ref()
89 .vortex_expect("validated child slot");
90 vortex_ensure!(child.len() == len, "MaskedArray child length mismatch");
91 vortex_ensure!(
92 child.dtype().as_nullable() == *dtype,
93 "MaskedArray dtype does not match child and validity"
94 );
95 Ok(())
96 }
97
98 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
99 0
100 }
101
102 fn buffer(_array: ArrayView<'_, Self>, _idx: usize) -> BufferHandle {
103 vortex_panic!("MaskedArray has no buffers")
104 }
105
106 fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
107 None
108 }
109
110 fn with_buffers(
111 &self,
112 array: ArrayView<'_, Self>,
113 buffers: &[BufferHandle],
114 ) -> VortexResult<ArrayParts<Self>> {
115 with_empty_buffers(self, array, buffers)
116 }
117
118 fn serialize(
119 _array: ArrayView<'_, Self>,
120 _session: &VortexSession,
121 ) -> VortexResult<Option<Vec<u8>>> {
122 Ok(Some(vec![]))
123 }
124
125 fn deserialize(
126 &self,
127 dtype: &DType,
128 len: usize,
129 metadata: &[u8],
130
131 buffers: &[BufferHandle],
132 children: &dyn ArrayChildren,
133 _session: &VortexSession,
134 ) -> VortexResult<ArrayParts<Self>> {
135 if !metadata.is_empty() {
136 vortex_bail!(
137 "MaskedArray expects empty metadata, got {} bytes",
138 metadata.len()
139 );
140 }
141 if !buffers.is_empty() {
142 vortex_bail!("Expected 0 buffer, got {}", buffers.len());
143 }
144
145 vortex_ensure!(
146 children.len() == 1 || children.len() == 2,
147 "`MaskedArray::build` expects 1 or 2 children, got {}",
148 children.len()
149 );
150
151 let child = children.get(0, &dtype.as_nonnullable(), len)?;
152
153 let validity = if children.len() == 2 {
154 let validity = children.get(1, &Validity::DTYPE, len)?;
155 Validity::Array(validity)
156 } else {
157 Validity::from(dtype.nullability())
158 };
159
160 let validity_slot = validity_to_child(&validity, len);
161 let data = MaskedData::try_new(
162 len,
163 child.all_valid(&mut LEGACY_SESSION.create_execution_ctx())?,
164 validity,
165 )?;
166 Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data)
167 .with_slots(smallvec![Some(child), validity_slot]))
168 }
169
170 fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
171 let array = require_child!(array, array.child(), MaskedSlots::CHILD => AnyCanonical);
172
173 let validity = array.masked_validity();
174
175 if validity.definitely_all_null() {
177 return Ok(ExecutionResult::done(
178 ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.len())
179 .into_array(),
180 ));
181 }
182
183 let child = Canonical::from(array.child().as_::<AnyCanonical>());
190 Ok(ExecutionResult::done(
191 mask_validity_canonical(child, validity, ctx)?.into_array(),
192 ))
193 }
194
195 fn reduce_parent(
196 array: ArrayView<'_, Self>,
197 parent: &ArrayRef,
198 child_idx: usize,
199 ) -> VortexResult<Option<ArrayRef>> {
200 PARENT_RULES.evaluate(array, parent, child_idx)
201 }
202 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
203 MaskedSlots::NAMES[idx].to_string()
204 }
205}
206
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBufferMut;
    use vortex_error::VortexError;
    use vortex_session::registry::ReadContext;

    use crate::ArrayContext;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::Masked;
    use crate::arrays::MaskedArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::Nullability;
    use crate::serde::SerializeOptions;
    use crate::serde::SerializedArray;
    use crate::validity::Validity;

    /// Serializes a masked array, joins the emitted buffers into one contiguous
    /// buffer, decodes it, and checks that the encoding and values survive.
    #[rstest]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            Validity::from_iter([true, true, false, true, false])
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter(0..100).into_array(),
            Validity::from_iter((0..100).map(|i| i % 3 != 0))
        ).unwrap()
    )]
    fn test_serde_roundtrip(#[case] array: MaskedArray) {
        let expected_dtype = array.dtype().clone();
        let expected_len = array.len();

        let array_ctx = ArrayContext::empty();
        let chunks = array
            .clone()
            .into_array()
            .serialize(&array_ctx, &array_session(), &SerializeOptions::default())
            .unwrap();

        // Stitch every serialized chunk into a single frozen buffer.
        let bytes = chunks
            .into_iter()
            .fold(ByteBufferMut::empty(), |mut acc, chunk| {
                acc.extend_from_slice(chunk.as_ref());
                acc
            })
            .freeze();

        let roundtripped = SerializedArray::try_from(bytes)
            .unwrap()
            .decode(
                &expected_dtype,
                expected_len,
                &ReadContext::new(array_ctx.to_ids()),
                &array_session(),
            )
            .unwrap();

        assert!(roundtripped.is::<Masked>());
        assert_eq!(
            array.as_ref().display_values().to_string(),
            roundtripped.display_values().to_string()
        );
    }

    /// A non-nullable child masked with all-valid validity must still execute
    /// to a canonical array whose dtype is nullable.
    #[test]
    fn test_execute_with_all_valid_preserves_nullable_dtype() -> Result<(), VortexError> {
        let values = PrimitiveArray::from_iter([1i32, 2, 3]).into_array();
        assert_eq!(values.dtype().nullability(), Nullability::NonNullable);

        let masked = MaskedArray::try_new(values, Validity::AllValid)?;
        assert_eq!(masked.dtype().nullability(), Nullability::Nullable);

        let mut exec_ctx = array_session().create_execution_ctx();
        let canonical: Canonical = masked.into_array().execute(&mut exec_ctx)?;

        assert_eq!(
            canonical.dtype().nullability(),
            Nullability::Nullable,
            "MaskedArray execute should produce Nullable dtype"
        );

        Ok(())
    }
}