Skip to main content

vortex_array/arrays/masked/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3mod canonical;
4mod operations;
5mod validity;
6
7use std::hash::Hash;
8
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_ensure;
13use vortex_error::vortex_panic;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::Canonical;
18use crate::EmptyMetadata;
19use crate::IntoArray;
20use crate::Precision;
21use crate::arrays::ConstantArray;
22use crate::arrays::MaskedArray;
23use crate::arrays::masked::compute::rules::PARENT_RULES;
24use crate::arrays::masked::mask_validity_canonical;
25use crate::buffer::BufferHandle;
26use crate::dtype::DType;
27use crate::executor::ExecutionCtx;
28use crate::executor::ExecutionStep;
29use crate::hash::ArrayEq;
30use crate::hash::ArrayHash;
31use crate::scalar::Scalar;
32use crate::serde::ArrayChildren;
33use crate::stats::StatsSetRef;
34use crate::validity::Validity;
35use crate::vtable;
36use crate::vtable::ArrayId;
37use crate::vtable::VTable;
38use crate::vtable::ValidityVTableFromValidityHelper;
39use crate::vtable::validity_nchildren;
40use crate::vtable::validity_to_child;
41vtable!(Masked);
42
43#[derive(Debug)]
44pub struct MaskedVTable;
45
46impl MaskedVTable {
47    pub const ID: ArrayId = ArrayId::new_ref("vortex.masked");
48}
49
50impl VTable for MaskedVTable {
51    type Array = MaskedArray;
52
53    type Metadata = EmptyMetadata;
54    type OperationsVTable = Self;
55    type ValidityVTable = ValidityVTableFromValidityHelper;
56
57    fn id(_array: &Self::Array) -> ArrayId {
58        Self::ID
59    }
60
61    fn len(array: &MaskedArray) -> usize {
62        array.child.len()
63    }
64
65    fn dtype(array: &MaskedArray) -> &DType {
66        &array.dtype
67    }
68
69    fn stats(array: &MaskedArray) -> StatsSetRef<'_> {
70        array.stats.to_ref(array.as_ref())
71    }
72
73    fn array_hash<H: std::hash::Hasher>(array: &MaskedArray, state: &mut H, precision: Precision) {
74        array.child.array_hash(state, precision);
75        array.validity.array_hash(state, precision);
76        array.dtype.hash(state);
77    }
78
79    fn array_eq(array: &MaskedArray, other: &MaskedArray, precision: Precision) -> bool {
80        array.child.array_eq(&other.child, precision)
81            && array.validity.array_eq(&other.validity, precision)
82            && array.dtype == other.dtype
83    }
84
85    fn nbuffers(_array: &Self::Array) -> usize {
86        0
87    }
88
89    fn buffer(_array: &Self::Array, _idx: usize) -> BufferHandle {
90        vortex_panic!("MaskedArray has no buffers")
91    }
92
93    fn buffer_name(_array: &Self::Array, _idx: usize) -> Option<String> {
94        None
95    }
96
97    fn nchildren(array: &Self::Array) -> usize {
98        1 + validity_nchildren(&array.validity)
99    }
100
101    fn child(array: &Self::Array, idx: usize) -> ArrayRef {
102        match idx {
103            0 => array.child.clone(),
104            1 => validity_to_child(&array.validity, array.child.len())
105                .vortex_expect("MaskedArray validity child out of bounds"),
106            _ => vortex_panic!("MaskedArray child index {idx} out of bounds"),
107        }
108    }
109
110    fn child_name(_array: &Self::Array, idx: usize) -> String {
111        match idx {
112            0 => "child".to_string(),
113            1 => "validity".to_string(),
114            _ => vortex_panic!("MaskedArray child_name index {idx} out of bounds"),
115        }
116    }
117
118    fn metadata(_array: &MaskedArray) -> VortexResult<Self::Metadata> {
119        Ok(EmptyMetadata)
120    }
121
122    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
123        Ok(Some(vec![]))
124    }
125
126    fn deserialize(
127        _bytes: &[u8],
128        _dtype: &DType,
129        _len: usize,
130        _buffers: &[BufferHandle],
131        _session: &VortexSession,
132    ) -> VortexResult<Self::Metadata> {
133        Ok(EmptyMetadata)
134    }
135
136    fn build(
137        dtype: &DType,
138        len: usize,
139        _metadata: &Self::Metadata,
140        buffers: &[BufferHandle],
141        children: &dyn ArrayChildren,
142    ) -> VortexResult<MaskedArray> {
143        if !buffers.is_empty() {
144            vortex_bail!("Expected 0 buffer, got {}", buffers.len());
145        }
146
147        let child = children.get(0, &dtype.as_nonnullable(), len)?;
148
149        let validity = if children.len() == 1 {
150            Validity::from(dtype.nullability())
151        } else if children.len() == 2 {
152            let validity = children.get(1, &Validity::DTYPE, len)?;
153            Validity::Array(validity)
154        } else {
155            vortex_bail!(
156                "`MaskedArray::build` expects 1 or 2 children, got {}",
157                children.len()
158            );
159        };
160
161        MaskedArray::try_new(child, validity)
162    }
163
164    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
165        let validity_mask = array.validity_mask()?;
166
167        // Fast path: all masked means result is all nulls.
168        if validity_mask.all_false() {
169            return Ok(ExecutionStep::Done(
170                ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.len())
171                    .into_array(),
172            ));
173        }
174
175        // NB: We intentionally do NOT have a fast path for `validity_mask.all_true()`.
176        // `MaskedArray`'s dtype is always `Nullable`, but the child has `NonNullable` `DType` (by
177        // invariant). Simply returning the child's canonical would cause a dtype mismatch.
178        // While we could manually convert the dtype, `mask_validity_canonical` is already O(1) for
179        // `AllTrue` masks (no data copying), so there's no benefit.
180
181        let child = array.child().clone().execute::<Canonical>(ctx)?;
182        Ok(ExecutionStep::Done(
183            mask_validity_canonical(child, &validity_mask, ctx)?.into_array(),
184        ))
185    }
186
187    fn reduce_parent(
188        array: &Self::Array,
189        parent: &ArrayRef,
190        child_idx: usize,
191    ) -> VortexResult<Option<ArrayRef>> {
192        PARENT_RULES.evaluate(array, parent, child_idx)
193    }
194
195    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
196        vortex_ensure!(
197            children.len() == 1 || children.len() == 2,
198            "MaskedArray expects 1 or 2 children, got {}",
199            children.len()
200        );
201
202        let mut iter = children.into_iter();
203        let child = iter
204            .next()
205            .vortex_expect("children length already validated");
206        let validity = if let Some(validity_array) = iter.next() {
207            Validity::Array(validity_array)
208        } else {
209            Validity::from(array.dtype.nullability())
210        };
211
212        let new_array = MaskedArray::try_new(child, validity)?;
213        *array = new_array;
214        Ok(())
215    }
216}
217
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBufferMut;
    use vortex_error::VortexError;
    use vortex_session::registry::ReadContext;

    use crate::ArrayContext;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::MaskedArray;
    use crate::arrays::MaskedVTable;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::Nullability;
    use crate::serde::ArrayParts;
    use crate::serde::SerializeOptions;
    use crate::validity::Validity;

    #[rstest]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            Validity::from_iter([true, true, false, true, false])
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter(0..100).into_array(),
            Validity::from_iter((0..100).map(|i| i % 3 != 0))
        ).unwrap()
    )]
    fn test_serde_roundtrip(#[case] array: MaskedArray) {
        let dtype = array.dtype().clone();
        let len = array.len();

        let ctx = ArrayContext::empty();
        let serialized = array
            .clone()
            .into_array()
            .serialize(&ctx, &SerializeOptions::default())
            .unwrap();

        // Concat into a single buffer.
        let mut concat = ByteBufferMut::empty();
        for buf in serialized {
            concat.extend_from_slice(buf.as_ref());
        }
        let concat = concat.freeze();

        let parts = ArrayParts::try_from(concat).unwrap();
        let decoded = parts
            .decode(
                &dtype,
                len,
                &ReadContext::new(ctx.to_ids()),
                &LEGACY_SESSION,
            )
            .unwrap();

        assert!(decoded.is::<MaskedVTable>());
        assert_eq!(
            array.as_ref().display_values().to_string(),
            decoded.display_values().to_string()
        );
    }

    /// Regression test for issue #5989: execute_fast_path returns child with wrong dtype.
    ///
    /// When MaskedArray's validity mask is all true, returning the child's canonical form
    /// directly would cause a dtype mismatch because the child has NonNullable dtype while
    /// MaskedArray always has Nullable dtype.
    #[test]
    fn test_execute_with_all_valid_preserves_nullable_dtype() -> Result<(), VortexError> {
        // Create a MaskedArray with AllValid validity. The child has NonNullable dtype, but
        // the MaskedArray's dtype is Nullable.
        let child = PrimitiveArray::from_iter([1i32, 2, 3]).into_array();
        assert_eq!(child.dtype().nullability(), Nullability::NonNullable);

        let array = MaskedArray::try_new(child, Validity::AllValid)?;
        assert_eq!(array.dtype().nullability(), Nullability::Nullable);

        // Execute the array. This should produce a Canonical with Nullable dtype.
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let result: Canonical = array.into_array().execute(&mut ctx)?;

        assert_eq!(
            result.as_ref().dtype().nullability(),
            Nullability::Nullable,
            "MaskedArray execute should produce Nullable dtype"
        );

        Ok(())
    }

    /// Covers the all-masked fast path in `execute`, which returns a constant null array
    /// instead of canonicalizing the child. The result must keep the `MaskedArray`'s length
    /// and `Nullable` dtype, and every position must be null.
    #[test]
    fn test_execute_with_all_invalid_returns_all_nulls() -> Result<(), VortexError> {
        let child = PrimitiveArray::from_iter([1i32, 2, 3]).into_array();
        let array = MaskedArray::try_new(child, Validity::from_iter([false, false, false]))?;

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let result: Canonical = array.into_array().execute(&mut ctx)?;

        assert_eq!(result.as_ref().len(), 3);
        assert_eq!(
            result.as_ref().dtype().nullability(),
            Nullability::Nullable,
            "all-masked MaskedArray should execute to a Nullable array"
        );
        assert!(
            result.as_ref().validity_mask()?.all_false(),
            "all-masked MaskedArray should execute to all nulls"
        );

        Ok(())
    }
}