vortex_array/arrays/masked/vtable/
mod.rs1mod canonical;
4mod operations;
5mod validity;
6
7use std::hash::Hash;
8
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_ensure;
13use vortex_error::vortex_panic;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::Canonical;
18use crate::EmptyMetadata;
19use crate::IntoArray;
20use crate::Precision;
21use crate::arrays::ConstantArray;
22use crate::arrays::MaskedArray;
23use crate::arrays::masked::compute::rules::PARENT_RULES;
24use crate::arrays::masked::mask_validity_canonical;
25use crate::buffer::BufferHandle;
26use crate::dtype::DType;
27use crate::executor::ExecutionCtx;
28use crate::executor::ExecutionStep;
29use crate::hash::ArrayEq;
30use crate::hash::ArrayHash;
31use crate::scalar::Scalar;
32use crate::serde::ArrayChildren;
33use crate::stats::StatsSetRef;
34use crate::validity::Validity;
35use crate::vtable;
36use crate::vtable::ArrayId;
37use crate::vtable::VTable;
38use crate::vtable::ValidityVTableFromValidityHelper;
39use crate::vtable::validity_nchildren;
40use crate::vtable::validity_to_child;
41vtable!(Masked);
42
43#[derive(Debug)]
44pub struct MaskedVTable;
45
46impl MaskedVTable {
47 pub const ID: ArrayId = ArrayId::new_ref("vortex.masked");
48}
49
50impl VTable for MaskedVTable {
51 type Array = MaskedArray;
52
53 type Metadata = EmptyMetadata;
54 type OperationsVTable = Self;
55 type ValidityVTable = ValidityVTableFromValidityHelper;
56
57 fn id(_array: &Self::Array) -> ArrayId {
58 Self::ID
59 }
60
61 fn len(array: &MaskedArray) -> usize {
62 array.child.len()
63 }
64
65 fn dtype(array: &MaskedArray) -> &DType {
66 &array.dtype
67 }
68
69 fn stats(array: &MaskedArray) -> StatsSetRef<'_> {
70 array.stats.to_ref(array.as_ref())
71 }
72
73 fn array_hash<H: std::hash::Hasher>(array: &MaskedArray, state: &mut H, precision: Precision) {
74 array.child.array_hash(state, precision);
75 array.validity.array_hash(state, precision);
76 array.dtype.hash(state);
77 }
78
79 fn array_eq(array: &MaskedArray, other: &MaskedArray, precision: Precision) -> bool {
80 array.child.array_eq(&other.child, precision)
81 && array.validity.array_eq(&other.validity, precision)
82 && array.dtype == other.dtype
83 }
84
85 fn nbuffers(_array: &Self::Array) -> usize {
86 0
87 }
88
89 fn buffer(_array: &Self::Array, _idx: usize) -> BufferHandle {
90 vortex_panic!("MaskedArray has no buffers")
91 }
92
93 fn buffer_name(_array: &Self::Array, _idx: usize) -> Option<String> {
94 None
95 }
96
97 fn nchildren(array: &Self::Array) -> usize {
98 1 + validity_nchildren(&array.validity)
99 }
100
101 fn child(array: &Self::Array, idx: usize) -> ArrayRef {
102 match idx {
103 0 => array.child.clone(),
104 1 => validity_to_child(&array.validity, array.child.len())
105 .vortex_expect("MaskedArray validity child out of bounds"),
106 _ => vortex_panic!("MaskedArray child index {idx} out of bounds"),
107 }
108 }
109
110 fn child_name(_array: &Self::Array, idx: usize) -> String {
111 match idx {
112 0 => "child".to_string(),
113 1 => "validity".to_string(),
114 _ => vortex_panic!("MaskedArray child_name index {idx} out of bounds"),
115 }
116 }
117
118 fn metadata(_array: &MaskedArray) -> VortexResult<Self::Metadata> {
119 Ok(EmptyMetadata)
120 }
121
122 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
123 Ok(Some(vec![]))
124 }
125
126 fn deserialize(
127 _bytes: &[u8],
128 _dtype: &DType,
129 _len: usize,
130 _buffers: &[BufferHandle],
131 _session: &VortexSession,
132 ) -> VortexResult<Self::Metadata> {
133 Ok(EmptyMetadata)
134 }
135
136 fn build(
137 dtype: &DType,
138 len: usize,
139 _metadata: &Self::Metadata,
140 buffers: &[BufferHandle],
141 children: &dyn ArrayChildren,
142 ) -> VortexResult<MaskedArray> {
143 if !buffers.is_empty() {
144 vortex_bail!("Expected 0 buffer, got {}", buffers.len());
145 }
146
147 let child = children.get(0, &dtype.as_nonnullable(), len)?;
148
149 let validity = if children.len() == 1 {
150 Validity::from(dtype.nullability())
151 } else if children.len() == 2 {
152 let validity = children.get(1, &Validity::DTYPE, len)?;
153 Validity::Array(validity)
154 } else {
155 vortex_bail!(
156 "`MaskedArray::build` expects 1 or 2 children, got {}",
157 children.len()
158 );
159 };
160
161 MaskedArray::try_new(child, validity)
162 }
163
164 fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
165 let validity_mask = array.validity_mask()?;
166
167 if validity_mask.all_false() {
169 return Ok(ExecutionStep::Done(
170 ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.len())
171 .into_array(),
172 ));
173 }
174
175 let child = array.child().clone().execute::<Canonical>(ctx)?;
182 Ok(ExecutionStep::Done(
183 mask_validity_canonical(child, &validity_mask, ctx)?.into_array(),
184 ))
185 }
186
187 fn reduce_parent(
188 array: &Self::Array,
189 parent: &ArrayRef,
190 child_idx: usize,
191 ) -> VortexResult<Option<ArrayRef>> {
192 PARENT_RULES.evaluate(array, parent, child_idx)
193 }
194
195 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
196 vortex_ensure!(
197 children.len() == 1 || children.len() == 2,
198 "MaskedArray expects 1 or 2 children, got {}",
199 children.len()
200 );
201
202 let mut iter = children.into_iter();
203 let child = iter
204 .next()
205 .vortex_expect("children length already validated");
206 let validity = if let Some(validity_array) = iter.next() {
207 Validity::Array(validity_array)
208 } else {
209 Validity::from(array.dtype.nullability())
210 };
211
212 let new_array = MaskedArray::try_new(child, validity)?;
213 *array = new_array;
214 Ok(())
215 }
216}
217
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBufferMut;
    use vortex_error::VortexError;
    use vortex_session::registry::ReadContext;

    use crate::ArrayContext;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::MaskedArray;
    use crate::arrays::MaskedVTable;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::Nullability;
    use crate::serde::ArrayParts;
    use crate::serde::SerializeOptions;
    use crate::validity::Validity;

    /// Serializes a masked array, decodes it again, and checks that the decoded
    /// array is still masked and renders the same values.
    #[rstest]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            Validity::from_iter([true, true, false, true, false])
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter(0..100).into_array(),
            Validity::from_iter((0..100).map(|i| i % 3 != 0))
        ).unwrap()
    )]
    fn test_serde_roundtrip(#[case] array: MaskedArray) {
        let expected_dtype = array.dtype().clone();
        let expected_len = array.len();

        let array_ctx = ArrayContext::empty();
        let chunks = array
            .clone()
            .into_array()
            .serialize(&array_ctx, &SerializeOptions::default())
            .unwrap();

        // Glue the serialized chunks into one contiguous buffer before decoding.
        let mut joined = ByteBufferMut::empty();
        chunks
            .into_iter()
            .for_each(|chunk| joined.extend_from_slice(chunk.as_ref()));
        let joined = joined.freeze();

        let read_ctx = ReadContext::new(array_ctx.to_ids());
        let decoded = ArrayParts::try_from(joined)
            .unwrap()
            .decode(&expected_dtype, expected_len, &read_ctx, &LEGACY_SESSION)
            .unwrap();

        assert!(decoded.is::<MaskedVTable>());
        assert_eq!(
            array.as_ref().display_values().to_string(),
            decoded.display_values().to_string()
        );
    }

    /// Executing an all-valid masked array over a non-nullable child must still
    /// yield a nullable canonical result.
    #[test]
    fn test_execute_with_all_valid_preserves_nullable_dtype() -> Result<(), VortexError> {
        let primitive = PrimitiveArray::from_iter([1i32, 2, 3]).into_array();
        // Precondition: the child starts out non-nullable.
        assert_eq!(primitive.dtype().nullability(), Nullability::NonNullable);

        let masked = MaskedArray::try_new(primitive, Validity::AllValid)?;
        assert_eq!(masked.dtype().nullability(), Nullability::Nullable);

        let mut exec_ctx = LEGACY_SESSION.create_execution_ctx();
        let canonical = masked.into_array().execute::<Canonical>(&mut exec_ctx)?;

        assert_eq!(
            canonical.as_ref().dtype().nullability(),
            Nullability::Nullable,
            "MaskedArray execute should produce Nullable dtype"
        );

        Ok(())
    }
}