vortex_array/arrays/masked/vtable/
mod.rs1mod array;
5mod canonical;
6mod operations;
7mod validity;
8
9use kernel::PARENT_KERNELS;
10use vortex_dtype::DType;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_error::vortex_ensure;
15use vortex_session::VortexSession;
16
17use crate::ArrayBufferVisitor;
18use crate::ArrayChildVisitor;
19use crate::ArrayRef;
20use crate::Canonical;
21use crate::EmptyMetadata;
22use crate::IntoArray;
23use crate::arrays::ConstantArray;
24use crate::arrays::masked::MaskedArray;
25use crate::arrays::masked::compute::rules::PARENT_RULES;
26use crate::arrays::masked::mask_validity_canonical;
27use crate::buffer::BufferHandle;
28use crate::executor::ExecutionCtx;
29use crate::scalar::Scalar;
30use crate::serde::ArrayChildren;
31use crate::validity::Validity;
32use crate::vtable;
33use crate::vtable::ArrayId;
34use crate::vtable::VTable;
35use crate::vtable::ValidityVTableFromValidityHelper;
36use crate::vtable::VisitorVTable;
37use crate::vtable::validity_nchildren;
38use crate::vtable::validity_to_child;
39
40mod kernel;
41
// Invoke the crate-level `vtable!` macro (imported from `crate::vtable`) for the `Masked`
// encoding.
// NOTE(review): presumably this generates the glue tying `MaskedArray` to `MaskedVTable`
// from the `Masked` prefix — confirm against the macro definition.
vtable!(Masked);
43
/// Marker type that carries the vtable implementations for [`MaskedArray`].
///
/// This is a unit struct with no fields; all per-array state lives in `MaskedArray`
/// itself, and this type only supplies the associated functions of the vtable traits.
#[derive(Debug)]
pub struct MaskedVTable;

impl MaskedVTable {
    /// Encoding identifier for masked arrays (`"vortex.masked"`); this is the value
    /// returned by `<MaskedVTable as VTable>::id` for every `MaskedArray`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.masked");
}
50
51impl VisitorVTable<MaskedVTable> for MaskedVTable {
52 fn visit_buffers(_array: &MaskedArray, _visitor: &mut dyn ArrayBufferVisitor) {}
53
54 fn visit_children(array: &MaskedArray, visitor: &mut dyn ArrayChildVisitor) {
55 visitor.visit_child("child", &array.child);
56 visitor.visit_validity(&array.validity, array.child.len());
57 }
58
59 fn nchildren(array: &MaskedArray) -> usize {
60 1 + validity_nchildren(&array.validity)
61 }
62
63 fn nth_child(array: &MaskedArray, idx: usize) -> Option<ArrayRef> {
64 match idx {
65 0 => Some(array.child.clone()),
66 1 => validity_to_child(&array.validity, array.child.len()),
67 _ => None,
68 }
69 }
70}
71
72impl VTable for MaskedVTable {
73 type Array = MaskedArray;
74
75 type Metadata = EmptyMetadata;
76
77 type ArrayVTable = Self;
78 type OperationsVTable = Self;
79 type ValidityVTable = ValidityVTableFromValidityHelper;
80 type VisitorVTable = Self;
81
82 fn id(_array: &Self::Array) -> ArrayId {
83 Self::ID
84 }
85
86 fn metadata(_array: &MaskedArray) -> VortexResult<Self::Metadata> {
87 Ok(EmptyMetadata)
88 }
89
90 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
91 Ok(Some(vec![]))
92 }
93
94 fn deserialize(
95 _bytes: &[u8],
96 _dtype: &DType,
97 _len: usize,
98 _buffers: &[BufferHandle],
99 _session: &VortexSession,
100 ) -> VortexResult<Self::Metadata> {
101 Ok(EmptyMetadata)
102 }
103
104 fn build(
105 dtype: &DType,
106 len: usize,
107 _metadata: &Self::Metadata,
108 buffers: &[BufferHandle],
109 children: &dyn ArrayChildren,
110 ) -> VortexResult<MaskedArray> {
111 if !buffers.is_empty() {
112 vortex_bail!("Expected 0 buffer, got {}", buffers.len());
113 }
114
115 let child = children.get(0, &dtype.as_nonnullable(), len)?;
116
117 let validity = if children.len() == 1 {
118 Validity::from(dtype.nullability())
119 } else if children.len() == 2 {
120 let validity = children.get(1, &Validity::DTYPE, len)?;
121 Validity::Array(validity)
122 } else {
123 vortex_bail!(
124 "`MaskedArray::build` expects 1 or 2 children, got {}",
125 children.len()
126 );
127 };
128
129 MaskedArray::try_new(child, validity)
130 }
131
132 fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
133 let validity_mask = array.validity_mask()?;
134
135 if validity_mask.all_false() {
137 return Ok(
138 ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.len())
139 .into_array(),
140 );
141 }
142
143 let child = array.child().clone().execute::<Canonical>(ctx)?;
150 Ok(mask_validity_canonical(child, &validity_mask, ctx)?.into_array())
151 }
152
153 fn reduce_parent(
154 array: &Self::Array,
155 parent: &ArrayRef,
156 child_idx: usize,
157 ) -> VortexResult<Option<ArrayRef>> {
158 PARENT_RULES.evaluate(array, parent, child_idx)
159 }
160
161 fn execute_parent(
162 array: &Self::Array,
163 parent: &ArrayRef,
164 child_idx: usize,
165 ctx: &mut ExecutionCtx,
166 ) -> VortexResult<Option<ArrayRef>> {
167 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
168 }
169
170 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
171 vortex_ensure!(
172 children.len() == 1 || children.len() == 2,
173 "MaskedArray expects 1 or 2 children, got {}",
174 children.len()
175 );
176
177 let mut iter = children.into_iter();
178 let child = iter
179 .next()
180 .vortex_expect("children length already validated");
181 let validity = if let Some(validity_array) = iter.next() {
182 Validity::Array(validity_array)
183 } else {
184 Validity::from(array.dtype.nullability())
185 };
186
187 let new_array = MaskedArray::try_new(child, validity)?;
188 *array = new_array;
189 Ok(())
190 }
191}
192
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBufferMut;
    use vortex_dtype::Nullability;
    use vortex_error::VortexError;

    use crate::ArrayContext;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::MaskedArray;
    use crate::arrays::MaskedVTable;
    use crate::arrays::PrimitiveArray;
    use crate::serde::ArrayParts;
    use crate::serde::SerializeOptions;
    use crate::validity::Validity;

    /// Serializing a `MaskedArray` and decoding the bytes again must produce a
    /// `MaskedArray` that displays exactly the same values.
    #[rstest]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            Validity::from_iter([true, true, false, true, false])
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter(0..100).into_array(),
            Validity::from_iter((0..100).map(|i| i % 3 != 0))
        ).unwrap()
    )]
    fn test_serde_roundtrip(#[case] array: MaskedArray) {
        let expected_dtype = array.dtype().clone();
        let expected_len = array.len();
        let ctx = ArrayContext::empty();

        let segments = array
            .to_array()
            .serialize(&ctx, &SerializeOptions::default())
            .unwrap();

        // Flatten every serialized segment into one contiguous byte buffer.
        let mut bytes = ByteBufferMut::empty();
        for segment in segments {
            bytes.extend_from_slice(segment.as_ref());
        }

        let parts = ArrayParts::try_from(bytes.freeze()).unwrap();
        let decoded = parts
            .decode(&expected_dtype, expected_len, &ctx, &LEGACY_SESSION)
            .unwrap();

        assert!(decoded.is::<MaskedVTable>());

        let expected = array.as_ref().display_values().to_string();
        let actual = decoded.display_values().to_string();
        assert_eq!(expected, actual);
    }

    /// Masking a non-nullable child widens the dtype to nullable, and executing the
    /// masked array must keep that nullable dtype even when every value is valid.
    #[test]
    fn test_execute_with_all_valid_preserves_nullable_dtype() -> Result<(), VortexError> {
        let child = PrimitiveArray::from_iter([1i32, 2, 3]).into_array();
        assert_eq!(child.dtype().nullability(), Nullability::NonNullable);

        let masked = MaskedArray::try_new(child, Validity::AllValid)?;
        assert_eq!(masked.dtype().nullability(), Nullability::Nullable);

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let canonical: Canonical = masked.into_array().execute(&mut ctx)?;
        let nullability = canonical.as_ref().dtype().nullability();

        assert_eq!(
            nullability,
            Nullability::Nullable,
            "MaskedArray execute should produce Nullable dtype"
        );

        Ok(())
    }
}