Skip to main content

vortex_array/arrays/masked/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3mod canonical;
4mod operations;
5mod validity;
6
7use std::hash::Hasher;
8
9use smallvec::smallvec;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_ensure;
14use vortex_error::vortex_panic;
15use vortex_session::VortexSession;
16use vortex_session::registry::CachedId;
17
18use crate::AnyCanonical;
19use crate::ArrayEq;
20use crate::ArrayHash;
21use crate::ArrayParts;
22use crate::ArrayRef;
23use crate::Canonical;
24use crate::EqMode;
25use crate::IntoArray;
26use crate::LEGACY_SESSION;
27use crate::VortexSessionExecute;
28use crate::array::Array;
29use crate::array::ArrayId;
30use crate::array::ArrayView;
31use crate::array::VTable;
32use crate::array::validity_to_child;
33use crate::array::with_empty_buffers;
34use crate::arrays::ConstantArray;
35use crate::arrays::masked::MaskedArrayExt;
36use crate::arrays::masked::MaskedArraySlotsExt;
37use crate::arrays::masked::MaskedData;
38use crate::arrays::masked::array::MaskedSlots;
39use crate::arrays::masked::compute::rules::PARENT_RULES;
40use crate::arrays::masked::mask_validity_canonical;
41use crate::buffer::BufferHandle;
42use crate::dtype::DType;
43use crate::executor::ExecutionCtx;
44use crate::executor::ExecutionResult;
45use crate::require_child;
46use crate::scalar::Scalar;
47use crate::serde::ArrayChildren;
48use crate::validity::Validity;
49/// A [`Masked`]-encoded Vortex array.
50pub type MaskedArray = Array<Masked>;
51
/// Marker type identifying the masked encoding.
///
/// The type is a unit struct and carries no state of its own; all behavior is attached
/// through the trait implementations on it.
#[derive(Debug, Clone)]
pub struct Masked;
54
55impl ArrayHash for MaskedData {
56    fn array_hash<H: Hasher>(&self, _state: &mut H, _accuracy: EqMode) {}
57}
58
59impl ArrayEq for MaskedData {
60    fn array_eq(&self, _other: &Self, _accuracy: EqMode) -> bool {
61        true
62    }
63}
64
65impl VTable for Masked {
66    type TypedArrayData = MaskedData;
67
68    type OperationsVTable = Self;
69    type ValidityVTable = Self;
70
71    fn id(&self) -> ArrayId {
72        static ID: CachedId = CachedId::new("vortex.masked");
73        *ID
74    }
75
76    fn validate(
77        &self,
78        _data: &MaskedData,
79        dtype: &DType,
80        len: usize,
81        slots: &[Option<ArrayRef>],
82    ) -> VortexResult<()> {
83        vortex_ensure!(
84            slots[MaskedSlots::CHILD].is_some(),
85            "MaskedArray child slot must be present"
86        );
87        let child = slots[MaskedSlots::CHILD]
88            .as_ref()
89            .vortex_expect("validated child slot");
90        vortex_ensure!(child.len() == len, "MaskedArray child length mismatch");
91        vortex_ensure!(
92            child.dtype().as_nullable() == *dtype,
93            "MaskedArray dtype does not match child and validity"
94        );
95        Ok(())
96    }
97
98    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
99        0
100    }
101
102    fn buffer(_array: ArrayView<'_, Self>, _idx: usize) -> BufferHandle {
103        vortex_panic!("MaskedArray has no buffers")
104    }
105
106    fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
107        None
108    }
109
110    fn with_buffers(
111        &self,
112        array: ArrayView<'_, Self>,
113        buffers: &[BufferHandle],
114    ) -> VortexResult<ArrayParts<Self>> {
115        with_empty_buffers(self, array, buffers)
116    }
117
118    fn serialize(
119        _array: ArrayView<'_, Self>,
120        _session: &VortexSession,
121    ) -> VortexResult<Option<Vec<u8>>> {
122        Ok(Some(vec![]))
123    }
124
125    fn deserialize(
126        &self,
127        dtype: &DType,
128        len: usize,
129        metadata: &[u8],
130
131        buffers: &[BufferHandle],
132        children: &dyn ArrayChildren,
133        _session: &VortexSession,
134    ) -> VortexResult<ArrayParts<Self>> {
135        if !metadata.is_empty() {
136            vortex_bail!(
137                "MaskedArray expects empty metadata, got {} bytes",
138                metadata.len()
139            );
140        }
141        if !buffers.is_empty() {
142            vortex_bail!("Expected 0 buffer, got {}", buffers.len());
143        }
144
145        vortex_ensure!(
146            children.len() == 1 || children.len() == 2,
147            "`MaskedArray::build` expects 1 or 2 children, got {}",
148            children.len()
149        );
150
151        let child = children.get(0, &dtype.as_nonnullable(), len)?;
152
153        let validity = if children.len() == 2 {
154            let validity = children.get(1, &Validity::DTYPE, len)?;
155            Validity::Array(validity)
156        } else {
157            Validity::from(dtype.nullability())
158        };
159
160        let validity_slot = validity_to_child(&validity, len);
161        let data = MaskedData::try_new(
162            len,
163            child.all_valid(&mut LEGACY_SESSION.create_execution_ctx())?,
164            validity,
165        )?;
166        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data)
167            .with_slots(smallvec![Some(child), validity_slot]))
168    }
169
170    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
171        let array = require_child!(array, array.child(), MaskedSlots::CHILD => AnyCanonical);
172
173        let validity = array.masked_validity();
174
175        // Fast path: all masked means result is all nulls.
176        if validity.definitely_all_null() {
177            return Ok(ExecutionResult::done(
178                ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.len())
179                    .into_array(),
180            ));
181        }
182
183        // NB: We intentionally do NOT have a fast path for `validity_mask.all_true()`.
184        // `MaskedArray`'s dtype is always `Nullable`, but the child has `NonNullable` `DType` (by
185        // invariant). Simply returning the child's canonical would cause a dtype mismatch.
186        // While we could manually convert the dtype, `mask_validity_canonical` is already O(1) for
187        // `AllTrue` masks (no data copying), so there's no benefit.
188
189        let child = Canonical::from(array.child().as_::<AnyCanonical>());
190        Ok(ExecutionResult::done(
191            mask_validity_canonical(child, validity, ctx)?.into_array(),
192        ))
193    }
194
195    fn reduce_parent(
196        array: ArrayView<'_, Self>,
197        parent: &ArrayRef,
198        child_idx: usize,
199    ) -> VortexResult<Option<ArrayRef>> {
200        PARENT_RULES.evaluate(array, parent, child_idx)
201    }
202    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
203        MaskedSlots::NAMES[idx].to_string()
204    }
205}
206
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBufferMut;
    use vortex_error::VortexError;
    use vortex_session::registry::ReadContext;

    use crate::ArrayContext;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::Masked;
    use crate::arrays::MaskedArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::Nullability;
    use crate::serde::SerializeOptions;
    use crate::serde::SerializedArray;
    use crate::validity::Validity;

    /// Serializes a masked array, decodes it again, and checks that the result is still
    /// `Masked`-encoded and displays the same values as the input.
    #[rstest]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            Validity::from_iter([true, true, false, true, false])
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter(0..100).into_array(),
            Validity::from_iter((0..100).map(|i| i % 3 != 0))
        ).unwrap()
    )]
    fn test_serde_roundtrip(#[case] array: MaskedArray) {
        let array_ctx = ArrayContext::empty();
        let expected_len = array.len();
        let expected_dtype = array.dtype().clone();

        let chunks = array
            .clone()
            .into_array()
            .serialize(&array_ctx, &array_session(), &SerializeOptions::default())
            .unwrap();

        // Flatten the serialized chunks into one contiguous buffer before decoding.
        let mut bytes = ByteBufferMut::empty();
        for chunk in chunks {
            bytes.extend_from_slice(chunk.as_ref());
        }
        let bytes = bytes.freeze();

        let parts = SerializedArray::try_from(bytes).unwrap();
        let read_ctx = ReadContext::new(array_ctx.to_ids());
        let decoded = parts
            .decode(&expected_dtype, expected_len, &read_ctx, &array_session())
            .unwrap();

        assert!(decoded.is::<Masked>());
        assert_eq!(
            array.as_ref().display_values().to_string(),
            decoded.display_values().to_string()
        );
    }

    /// Regression test for issue #5989: execute_fast_path returns child with wrong dtype.
    ///
    /// With an all-valid mask, handing back the child's canonical form unchanged would yield a
    /// NonNullable dtype, whereas a MaskedArray always has a Nullable dtype.
    #[test]
    fn test_execute_with_all_valid_preserves_nullable_dtype() -> Result<(), VortexError> {
        // The child is NonNullable ...
        let values = PrimitiveArray::from_iter([1i32, 2, 3]).into_array();
        assert_eq!(values.dtype().nullability(), Nullability::NonNullable);

        // ... while the MaskedArray wrapping it with AllValid validity is Nullable.
        let masked = MaskedArray::try_new(values, Validity::AllValid)?;
        assert_eq!(masked.dtype().nullability(), Nullability::Nullable);

        // Executing must keep the Nullable dtype on the canonical result.
        let mut exec_ctx = array_session().create_execution_ctx();
        let canonical: Canonical = masked.into_array().execute(&mut exec_ctx)?;

        assert_eq!(
            canonical.dtype().nullability(),
            Nullability::Nullable,
            "MaskedArray execute should produce Nullable dtype"
        );

        Ok(())
    }
}