vortex_array/arrays/masked/vtable/
mod.rs1mod canonical;
4mod operations;
5mod validity;
6
7use std::hash::Hash;
8
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_ensure;
13use vortex_error::vortex_panic;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::Canonical;
18use crate::EmptyMetadata;
19use crate::IntoArray;
20use crate::Precision;
21use crate::arrays::ConstantArray;
22use crate::arrays::masked::MaskedArray;
23use crate::arrays::masked::compute::rules::PARENT_RULES;
24use crate::arrays::masked::mask_validity_canonical;
25use crate::buffer::BufferHandle;
26use crate::dtype::DType;
27use crate::executor::ExecutionCtx;
28use crate::hash::ArrayEq;
29use crate::hash::ArrayHash;
30use crate::scalar::Scalar;
31use crate::serde::ArrayChildren;
32use crate::stats::StatsSetRef;
33use crate::validity::Validity;
34use crate::vtable;
35use crate::vtable::ArrayId;
36use crate::vtable::VTable;
37use crate::vtable::ValidityVTableFromValidityHelper;
38use crate::vtable::validity_nchildren;
39use crate::vtable::validity_to_child;
// Presumably generates the `Masked`-prefixed vtable boilerplate for this encoding — see the
// `vtable!` macro definition for exactly what it expands to.
vtable!(Masked);

/// Vtable for the masked encoding; its arrays are represented by [`MaskedArray`].
#[derive(Debug)]
pub struct MaskedVTable;

impl MaskedVTable {
    /// Stable identifier of the masked encoding (`"vortex.masked"`), returned by `VTable::id`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.masked");
}
48
49impl VTable for MaskedVTable {
50 type Array = MaskedArray;
51
52 type Metadata = EmptyMetadata;
53 type OperationsVTable = Self;
54 type ValidityVTable = ValidityVTableFromValidityHelper;
55
56 fn id(_array: &Self::Array) -> ArrayId {
57 Self::ID
58 }
59
60 fn len(array: &MaskedArray) -> usize {
61 array.child.len()
62 }
63
64 fn dtype(array: &MaskedArray) -> &DType {
65 &array.dtype
66 }
67
68 fn stats(array: &MaskedArray) -> StatsSetRef<'_> {
69 array.stats.to_ref(array.as_ref())
70 }
71
72 fn array_hash<H: std::hash::Hasher>(array: &MaskedArray, state: &mut H, precision: Precision) {
73 array.child.array_hash(state, precision);
74 array.validity.array_hash(state, precision);
75 array.dtype.hash(state);
76 }
77
78 fn array_eq(array: &MaskedArray, other: &MaskedArray, precision: Precision) -> bool {
79 array.child.array_eq(&other.child, precision)
80 && array.validity.array_eq(&other.validity, precision)
81 && array.dtype == other.dtype
82 }
83
84 fn nbuffers(_array: &Self::Array) -> usize {
85 0
86 }
87
88 fn buffer(_array: &Self::Array, _idx: usize) -> BufferHandle {
89 vortex_panic!("MaskedArray has no buffers")
90 }
91
92 fn buffer_name(_array: &Self::Array, _idx: usize) -> Option<String> {
93 None
94 }
95
96 fn nchildren(array: &Self::Array) -> usize {
97 1 + validity_nchildren(&array.validity)
98 }
99
100 fn child(array: &Self::Array, idx: usize) -> ArrayRef {
101 match idx {
102 0 => array.child.clone(),
103 1 => validity_to_child(&array.validity, array.child.len())
104 .vortex_expect("MaskedArray validity child out of bounds"),
105 _ => vortex_panic!("MaskedArray child index {idx} out of bounds"),
106 }
107 }
108
109 fn child_name(_array: &Self::Array, idx: usize) -> String {
110 match idx {
111 0 => "child".to_string(),
112 1 => "validity".to_string(),
113 _ => vortex_panic!("MaskedArray child_name index {idx} out of bounds"),
114 }
115 }
116
117 fn metadata(_array: &MaskedArray) -> VortexResult<Self::Metadata> {
118 Ok(EmptyMetadata)
119 }
120
121 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
122 Ok(Some(vec![]))
123 }
124
125 fn deserialize(
126 _bytes: &[u8],
127 _dtype: &DType,
128 _len: usize,
129 _buffers: &[BufferHandle],
130 _session: &VortexSession,
131 ) -> VortexResult<Self::Metadata> {
132 Ok(EmptyMetadata)
133 }
134
135 fn build(
136 dtype: &DType,
137 len: usize,
138 _metadata: &Self::Metadata,
139 buffers: &[BufferHandle],
140 children: &dyn ArrayChildren,
141 ) -> VortexResult<MaskedArray> {
142 if !buffers.is_empty() {
143 vortex_bail!("Expected 0 buffer, got {}", buffers.len());
144 }
145
146 let child = children.get(0, &dtype.as_nonnullable(), len)?;
147
148 let validity = if children.len() == 1 {
149 Validity::from(dtype.nullability())
150 } else if children.len() == 2 {
151 let validity = children.get(1, &Validity::DTYPE, len)?;
152 Validity::Array(validity)
153 } else {
154 vortex_bail!(
155 "`MaskedArray::build` expects 1 or 2 children, got {}",
156 children.len()
157 );
158 };
159
160 MaskedArray::try_new(child, validity)
161 }
162
163 fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
164 let validity_mask = array.validity_mask()?;
165
166 if validity_mask.all_false() {
168 return Ok(
169 ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.len())
170 .into_array(),
171 );
172 }
173
174 let child = array.child().clone().execute::<Canonical>(ctx)?;
181 Ok(mask_validity_canonical(child, &validity_mask, ctx)?.into_array())
182 }
183
184 fn reduce_parent(
185 array: &Self::Array,
186 parent: &ArrayRef,
187 child_idx: usize,
188 ) -> VortexResult<Option<ArrayRef>> {
189 PARENT_RULES.evaluate(array, parent, child_idx)
190 }
191
192 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
193 vortex_ensure!(
194 children.len() == 1 || children.len() == 2,
195 "MaskedArray expects 1 or 2 children, got {}",
196 children.len()
197 );
198
199 let mut iter = children.into_iter();
200 let child = iter
201 .next()
202 .vortex_expect("children length already validated");
203 let validity = if let Some(validity_array) = iter.next() {
204 Validity::Array(validity_array)
205 } else {
206 Validity::from(array.dtype.nullability())
207 };
208
209 let new_array = MaskedArray::try_new(child, validity)?;
210 *array = new_array;
211 Ok(())
212 }
213}
214
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBufferMut;
    use vortex_error::VortexError;

    use crate::ArrayContext;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::MaskedArray;
    use crate::arrays::MaskedVTable;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::Nullability;
    use crate::serde::ArrayParts;
    use crate::serde::SerializeOptions;
    use crate::validity::Validity;

    /// Serializing a `MaskedArray`, concatenating the emitted segments and decoding them again
    /// must produce a `MaskedArray` that displays the same values as the input.
    #[rstest]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            Validity::from_iter([true, true, false, true, false])
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter(0..100).into_array(),
            Validity::from_iter((0..100).map(|i| i % 3 != 0))
        ).unwrap()
    )]
    fn test_serde_roundtrip(#[case] array: MaskedArray) {
        let (expected_dtype, expected_len) = (array.dtype().clone(), array.len());
        let ctx = ArrayContext::empty();

        // `ArrayParts` parses one contiguous buffer, so flatten every serialized segment.
        let mut bytes = ByteBufferMut::empty();
        array
            .to_array()
            .serialize(&ctx, &SerializeOptions::default())
            .unwrap()
            .into_iter()
            .for_each(|segment| bytes.extend_from_slice(segment.as_ref()));

        let roundtripped = ArrayParts::try_from(bytes.freeze())
            .unwrap()
            .decode(&expected_dtype, expected_len, &ctx, &LEGACY_SESSION)
            .unwrap();

        assert!(roundtripped.is::<MaskedVTable>());
        let before = array.as_ref().display_values().to_string();
        let after = roundtripped.display_values().to_string();
        assert_eq!(before, after);
    }

    /// Executing an all-valid `MaskedArray` over a non-nullable child must still yield a
    /// nullable canonical array, matching the masked array's own dtype.
    #[test]
    fn test_execute_with_all_valid_preserves_nullable_dtype() -> Result<(), VortexError> {
        let values = PrimitiveArray::from_iter([1i32, 2, 3]).into_array();
        // Nullability must come from the mask, not from the child.
        assert_eq!(values.dtype().nullability(), Nullability::NonNullable);

        let masked = MaskedArray::try_new(values, Validity::AllValid)?;
        assert_eq!(masked.dtype().nullability(), Nullability::Nullable);

        let mut exec_ctx = LEGACY_SESSION.create_execution_ctx();
        let canonical = masked.into_array().execute::<Canonical>(&mut exec_ctx)?;

        assert_eq!(
            canonical.as_ref().dtype().nullability(),
            Nullability::Nullable,
            "MaskedArray execute should produce Nullable dtype"
        );

        Ok(())
    }
}