Skip to main content

vortex_array/arrays/masked/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3mod canonical;
4mod operations;
5mod validity;
6
7use std::hash::Hasher;
8
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_ensure;
13use vortex_error::vortex_panic;
14use vortex_session::VortexSession;
15use vortex_session::registry::CachedId;
16
17use crate::AnyCanonical;
18use crate::ArrayEq;
19use crate::ArrayHash;
20use crate::ArrayRef;
21use crate::Canonical;
22use crate::IntoArray;
23use crate::LEGACY_SESSION;
24use crate::Precision;
25use crate::VortexSessionExecute;
26use crate::array::Array;
27use crate::array::ArrayId;
28use crate::array::ArrayView;
29use crate::array::VTable;
30use crate::array::validity_to_child;
31use crate::arrays::ConstantArray;
32use crate::arrays::masked::MaskedArrayExt;
33use crate::arrays::masked::MaskedData;
34use crate::arrays::masked::array::CHILD_SLOT;
35use crate::arrays::masked::array::SLOT_NAMES;
36use crate::arrays::masked::array::VALIDITY_SLOT;
37use crate::arrays::masked::compute::rules::PARENT_RULES;
38use crate::arrays::masked::mask_validity_canonical;
39use crate::buffer::BufferHandle;
40use crate::dtype::DType;
41use crate::executor::ExecutionCtx;
42use crate::executor::ExecutionResult;
43use crate::require_child;
44use crate::require_validity;
45use crate::scalar::Scalar;
46use crate::serde::ArrayChildren;
47use crate::validity::Validity;
48/// A [`Masked`]-encoded Vortex array.
49pub type MaskedArray = Array<Masked>;
50
/// Zero-sized marker type for the masked encoding.
///
/// It carries no state; the encoding's behavior is supplied by its `VTable` implementation.
#[derive(Debug, Clone)]
pub struct Masked;
53
54impl ArrayHash for MaskedData {
55    fn array_hash<H: Hasher>(&self, _state: &mut H, _precision: Precision) {}
56}
57
58impl ArrayEq for MaskedData {
59    fn array_eq(&self, _other: &Self, _precision: Precision) -> bool {
60        true
61    }
62}
63
64impl VTable for Masked {
65    type ArrayData = MaskedData;
66
67    type OperationsVTable = Self;
68    type ValidityVTable = Self;
69
70    fn id(&self) -> ArrayId {
71        static ID: CachedId = CachedId::new("vortex.masked");
72        *ID
73    }
74
75    fn validate(
76        &self,
77        _data: &MaskedData,
78        dtype: &DType,
79        len: usize,
80        slots: &[Option<ArrayRef>],
81    ) -> VortexResult<()> {
82        vortex_ensure!(
83            slots[CHILD_SLOT].is_some(),
84            "MaskedArray child slot must be present"
85        );
86        let child = slots[CHILD_SLOT]
87            .as_ref()
88            .vortex_expect("validated child slot");
89        vortex_ensure!(child.len() == len, "MaskedArray child length mismatch");
90        vortex_ensure!(
91            child.dtype().as_nullable() == *dtype,
92            "MaskedArray dtype does not match child and validity"
93        );
94        Ok(())
95    }
96
97    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
98        0
99    }
100
101    fn buffer(_array: ArrayView<'_, Self>, _idx: usize) -> BufferHandle {
102        vortex_panic!("MaskedArray has no buffers")
103    }
104
105    fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
106        None
107    }
108
109    fn serialize(
110        _array: ArrayView<'_, Self>,
111        _session: &VortexSession,
112    ) -> VortexResult<Option<Vec<u8>>> {
113        Ok(Some(vec![]))
114    }
115
116    fn deserialize(
117        &self,
118        dtype: &DType,
119        len: usize,
120        metadata: &[u8],
121
122        buffers: &[BufferHandle],
123        children: &dyn ArrayChildren,
124        _session: &VortexSession,
125    ) -> VortexResult<crate::array::ArrayParts<Self>> {
126        if !metadata.is_empty() {
127            vortex_bail!(
128                "MaskedArray expects empty metadata, got {} bytes",
129                metadata.len()
130            );
131        }
132        if !buffers.is_empty() {
133            vortex_bail!("Expected 0 buffer, got {}", buffers.len());
134        }
135
136        vortex_ensure!(
137            children.len() == 1 || children.len() == 2,
138            "`MaskedArray::build` expects 1 or 2 children, got {}",
139            children.len()
140        );
141
142        let child = children.get(0, &dtype.as_nonnullable(), len)?;
143
144        let validity = if children.len() == 2 {
145            let validity = children.get(1, &Validity::DTYPE, len)?;
146            Validity::Array(validity)
147        } else {
148            Validity::from(dtype.nullability())
149        };
150
151        let validity_slot = validity_to_child(&validity, len);
152        let data = MaskedData::try_new(
153            len,
154            child.all_valid(&mut LEGACY_SESSION.create_execution_ctx())?,
155            validity,
156        )?;
157        Ok(
158            crate::array::ArrayParts::new(self.clone(), dtype.clone(), len, data)
159                .with_slots(vec![Some(child), validity_slot]),
160        )
161    }
162
163    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
164        let validity_mask = array.masked_validity().to_mask(array.len(), ctx)?;
165
166        // Fast path: all masked means result is all nulls.
167        if validity_mask.all_false() {
168            return Ok(ExecutionResult::done(
169                ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.len())
170                    .into_array(),
171            ));
172        }
173
174        // NB: We intentionally do NOT have a fast path for `validity_mask.all_true()`.
175        // `MaskedArray`'s dtype is always `Nullable`, but the child has `NonNullable` `DType` (by
176        // invariant). Simply returning the child's canonical would cause a dtype mismatch.
177        // While we could manually convert the dtype, `mask_validity_canonical` is already O(1) for
178        // `AllTrue` masks (no data copying), so there's no benefit.
179
180        let array = require_child!(array, array.child(), CHILD_SLOT => AnyCanonical);
181        require_validity!(array, VALIDITY_SLOT);
182
183        let child = Canonical::from(array.child().as_::<AnyCanonical>());
184        Ok(ExecutionResult::done(
185            mask_validity_canonical(child, &validity_mask, ctx)?.into_array(),
186        ))
187    }
188
189    fn reduce_parent(
190        array: ArrayView<'_, Self>,
191        parent: &ArrayRef,
192        child_idx: usize,
193    ) -> VortexResult<Option<ArrayRef>> {
194        PARENT_RULES.evaluate(array, parent, child_idx)
195    }
196    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
197        SLOT_NAMES[idx].to_string()
198    }
199}
200
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBufferMut;
    use vortex_error::VortexError;
    use vortex_session::registry::ReadContext;

    use crate::ArrayContext;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::Masked;
    use crate::arrays::MaskedArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::Nullability;
    use crate::serde::SerializeOptions;
    use crate::serde::SerializedArray;
    use crate::validity::Validity;

    /// Serializes a `MaskedArray`, decodes it again, and checks the decoded array matches the
    /// original.
    ///
    /// The decoded array must still be `Masked`, with the same dtype, length and displayed
    /// values.
    #[rstest]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            Validity::from_iter([true, true, false, true, false])
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter(0..100).into_array(),
            Validity::from_iter((0..100).map(|i| i % 3 != 0))
        ).unwrap()
    )]
    fn test_serde_roundtrip(#[case] array: MaskedArray) {
        let dtype = array.dtype().clone();
        let len = array.len();

        let ctx = ArrayContext::empty();
        let serialized = array
            .clone()
            .into_array()
            .serialize(&ctx, &LEGACY_SESSION, &SerializeOptions::default())
            .unwrap();

        // Concat into a single buffer.
        let mut concat = ByteBufferMut::empty();
        for buf in serialized {
            concat.extend_from_slice(buf.as_ref());
        }
        let concat = concat.freeze();

        let parts = SerializedArray::try_from(concat).unwrap();
        let decoded = parts
            .decode(
                &dtype,
                len,
                &ReadContext::new(ctx.to_ids()),
                &LEGACY_SESSION,
            )
            .unwrap();

        assert!(decoded.is::<Masked>());
        // Display strings alone could hide a dtype or length mismatch, so check both explicitly.
        assert_eq!(decoded.dtype(), &dtype);
        assert_eq!(decoded.len(), len);
        assert_eq!(
            array.as_ref().display_values().to_string(),
            decoded.display_values().to_string()
        );
    }

    /// Regression test for issue #5989: execute_fast_path returns child with wrong dtype.
    ///
    /// When MaskedArray's validity mask is all true, returning the child's canonical form
    /// directly would cause a dtype mismatch because the child has NonNullable dtype while
    /// MaskedArray always has Nullable dtype.
    #[test]
    fn test_execute_with_all_valid_preserves_nullable_dtype() -> Result<(), VortexError> {
        // Build a MaskedArray with AllValid validity. The child has a NonNullable dtype, but
        // the MaskedArray's dtype is Nullable.
        let child = PrimitiveArray::from_iter([1i32, 2, 3]).into_array();
        assert_eq!(child.dtype().nullability(), Nullability::NonNullable);

        let array = MaskedArray::try_new(child, Validity::AllValid)?;
        assert_eq!(array.dtype().nullability(), Nullability::Nullable);

        // Execute the array. This should produce a Canonical with Nullable dtype.
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let result: Canonical = array.into_array().execute(&mut ctx)?;

        assert_eq!(
            result.dtype().nullability(),
            Nullability::Nullable,
            "MaskedArray execute should produce Nullable dtype"
        );

        Ok(())
    }

    /// When every position is masked, `execute` takes its all-null fast path.
    ///
    /// The canonical result must still carry the array's Nullable dtype.
    #[test]
    fn test_execute_all_masked_preserves_nullable_dtype() -> Result<(), VortexError> {
        let child = PrimitiveArray::from_iter([1i32, 2, 3]).into_array();
        let array = MaskedArray::try_new(child, Validity::from_iter([false, false, false]))?;

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let result: Canonical = array.into_array().execute(&mut ctx)?;

        assert_eq!(
            result.dtype().nullability(),
            Nullability::Nullable,
            "all-masked MaskedArray execute should produce Nullable dtype"
        );

        Ok(())
    }
}