vortex_array/arrays/masked/vtable/
mod.rs1mod canonical;
4mod operations;
5mod validity;
6
7use std::hash::Hasher;
8
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_ensure;
13use vortex_error::vortex_panic;
14use vortex_session::VortexSession;
15use vortex_session::registry::CachedId;
16
17use crate::AnyCanonical;
18use crate::ArrayEq;
19use crate::ArrayHash;
20use crate::ArrayRef;
21use crate::Canonical;
22use crate::IntoArray;
23use crate::LEGACY_SESSION;
24use crate::Precision;
25use crate::VortexSessionExecute;
26use crate::array::Array;
27use crate::array::ArrayId;
28use crate::array::ArrayView;
29use crate::array::VTable;
30use crate::array::validity_to_child;
31use crate::arrays::ConstantArray;
32use crate::arrays::masked::MaskedArrayExt;
33use crate::arrays::masked::MaskedArraySlotsExt;
34use crate::arrays::masked::MaskedData;
35use crate::arrays::masked::array::MaskedSlots;
36use crate::arrays::masked::compute::rules::PARENT_RULES;
37use crate::arrays::masked::mask_validity_canonical;
38use crate::buffer::BufferHandle;
39use crate::dtype::DType;
40use crate::executor::ExecutionCtx;
41use crate::executor::ExecutionResult;
42use crate::require_child;
43use crate::scalar::Scalar;
44use crate::serde::ArrayChildren;
45use crate::validity::Validity;
46pub type MaskedArray = Array<Masked>;
48
49#[derive(Clone, Debug)]
50pub struct Masked;
51
52impl ArrayHash for MaskedData {
53 fn array_hash<H: Hasher>(&self, _state: &mut H, _precision: Precision) {}
54}
55
56impl ArrayEq for MaskedData {
57 fn array_eq(&self, _other: &Self, _precision: Precision) -> bool {
58 true
59 }
60}
61
62impl VTable for Masked {
63 type ArrayData = MaskedData;
64
65 type OperationsVTable = Self;
66 type ValidityVTable = Self;
67
68 fn id(&self) -> ArrayId {
69 static ID: CachedId = CachedId::new("vortex.masked");
70 *ID
71 }
72
73 fn validate(
74 &self,
75 _data: &MaskedData,
76 dtype: &DType,
77 len: usize,
78 slots: &[Option<ArrayRef>],
79 ) -> VortexResult<()> {
80 vortex_ensure!(
81 slots[MaskedSlots::CHILD].is_some(),
82 "MaskedArray child slot must be present"
83 );
84 let child = slots[MaskedSlots::CHILD]
85 .as_ref()
86 .vortex_expect("validated child slot");
87 vortex_ensure!(child.len() == len, "MaskedArray child length mismatch");
88 vortex_ensure!(
89 child.dtype().as_nullable() == *dtype,
90 "MaskedArray dtype does not match child and validity"
91 );
92 Ok(())
93 }
94
95 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
96 0
97 }
98
99 fn buffer(_array: ArrayView<'_, Self>, _idx: usize) -> BufferHandle {
100 vortex_panic!("MaskedArray has no buffers")
101 }
102
103 fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
104 None
105 }
106
107 fn serialize(
108 _array: ArrayView<'_, Self>,
109 _session: &VortexSession,
110 ) -> VortexResult<Option<Vec<u8>>> {
111 Ok(Some(vec![]))
112 }
113
114 fn deserialize(
115 &self,
116 dtype: &DType,
117 len: usize,
118 metadata: &[u8],
119
120 buffers: &[BufferHandle],
121 children: &dyn ArrayChildren,
122 _session: &VortexSession,
123 ) -> VortexResult<crate::array::ArrayParts<Self>> {
124 if !metadata.is_empty() {
125 vortex_bail!(
126 "MaskedArray expects empty metadata, got {} bytes",
127 metadata.len()
128 );
129 }
130 if !buffers.is_empty() {
131 vortex_bail!("Expected 0 buffer, got {}", buffers.len());
132 }
133
134 vortex_ensure!(
135 children.len() == 1 || children.len() == 2,
136 "`MaskedArray::build` expects 1 or 2 children, got {}",
137 children.len()
138 );
139
140 let child = children.get(0, &dtype.as_nonnullable(), len)?;
141
142 let validity = if children.len() == 2 {
143 let validity = children.get(1, &Validity::DTYPE, len)?;
144 Validity::Array(validity)
145 } else {
146 Validity::from(dtype.nullability())
147 };
148
149 let validity_slot = validity_to_child(&validity, len);
150 let data = MaskedData::try_new(
151 len,
152 child.all_valid(&mut LEGACY_SESSION.create_execution_ctx())?,
153 validity,
154 )?;
155 Ok(
156 crate::array::ArrayParts::new(self.clone(), dtype.clone(), len, data)
157 .with_slots(vec![Some(child), validity_slot]),
158 )
159 }
160
161 fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
162 let array = require_child!(array, array.child(), MaskedSlots::CHILD => AnyCanonical);
163
164 let validity = array.masked_validity();
165
166 if matches!(validity, Validity::AllInvalid) {
168 return Ok(ExecutionResult::done(
169 ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.len())
170 .into_array(),
171 ));
172 }
173
174 let child = Canonical::from(array.child().as_::<AnyCanonical>());
181 Ok(ExecutionResult::done(
182 mask_validity_canonical(child, validity, ctx)?.into_array(),
183 ))
184 }
185
186 fn reduce_parent(
187 array: ArrayView<'_, Self>,
188 parent: &ArrayRef,
189 child_idx: usize,
190 ) -> VortexResult<Option<ArrayRef>> {
191 PARENT_RULES.evaluate(array, parent, child_idx)
192 }
193 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
194 MaskedSlots::NAMES[idx].to_string()
195 }
196}
197
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBufferMut;
    use vortex_error::VortexError;
    use vortex_session::registry::ReadContext;

    use crate::ArrayContext;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::Masked;
    use crate::arrays::MaskedArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::Nullability;
    use crate::serde::SerializeOptions;
    use crate::serde::SerializedArray;
    use crate::validity::Validity;

    /// Round-trips a `MaskedArray` through serialization.
    ///
    /// The serialized segments are joined into one contiguous buffer and decoded. The test
    /// then checks that the result is still a `Masked` array that displays the same values.
    #[rstest]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            Validity::from_iter([true, true, false, true, false])
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter(0..100).into_array(),
            Validity::from_iter((0..100).map(|i| i % 3 != 0))
        ).unwrap()
    )]
    fn test_serde_roundtrip(#[case] array: MaskedArray) {
        let (dtype, len) = (array.dtype().clone(), array.len());

        let ctx = ArrayContext::empty();
        let segments = array
            .clone()
            .into_array()
            .serialize(&ctx, &LEGACY_SESSION, &SerializeOptions::default())
            .unwrap();

        // Flatten the serialized segments into a single contiguous byte buffer.
        let mut bytes = ByteBufferMut::empty();
        for segment in segments {
            bytes.extend_from_slice(segment.as_ref());
        }
        let parts = SerializedArray::try_from(bytes.freeze()).unwrap();

        let read_ctx = ReadContext::new(ctx.to_ids());
        let decoded = parts
            .decode(&dtype, len, &read_ctx, &LEGACY_SESSION)
            .unwrap();

        assert!(decoded.is::<Masked>());
        let expected = array.as_ref().display_values().to_string();
        let actual = decoded.display_values().to_string();
        assert_eq!(expected, actual);
    }

    /// Wrapping a non-nullable child with `AllValid` validity yields a nullable dtype, and
    /// executing to canonical form keeps that nullability.
    #[test]
    fn test_execute_with_all_valid_preserves_nullable_dtype() -> Result<(), VortexError> {
        let child = PrimitiveArray::from_iter([1i32, 2, 3]).into_array();
        assert_eq!(child.dtype().nullability(), Nullability::NonNullable);

        let masked = MaskedArray::try_new(child, Validity::AllValid)?;
        assert_eq!(masked.dtype().nullability(), Nullability::Nullable);

        let mut exec_ctx = LEGACY_SESSION.create_execution_ctx();
        let canonical: Canonical = masked.into_array().execute(&mut exec_ctx)?;
        assert_eq!(
            canonical.dtype().nullability(),
            Nullability::Nullable,
            "MaskedArray execute should produce Nullable dtype"
        );

        Ok(())
    }
}