Skip to main content

vortex_array/arrays/masked/vtable/
mod.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3mod canonical;
4mod operations;
5mod validity;
6
7use std::hash::Hash;
8
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_ensure;
13use vortex_error::vortex_panic;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::Canonical;
18use crate::EmptyMetadata;
19use crate::IntoArray;
20use crate::Precision;
21use crate::arrays::ConstantArray;
22use crate::arrays::masked::MaskedArray;
23use crate::arrays::masked::compute::rules::PARENT_RULES;
24use crate::arrays::masked::mask_validity_canonical;
25use crate::buffer::BufferHandle;
26use crate::dtype::DType;
27use crate::executor::ExecutionCtx;
28use crate::hash::ArrayEq;
29use crate::hash::ArrayHash;
30use crate::scalar::Scalar;
31use crate::serde::ArrayChildren;
32use crate::stats::StatsSetRef;
33use crate::validity::Validity;
34use crate::vtable;
35use crate::vtable::ArrayId;
36use crate::vtable::VTable;
37use crate::vtable::ValidityVTableFromValidityHelper;
38use crate::vtable::validity_nchildren;
39use crate::vtable::validity_to_child;
40vtable!(Masked);
41
/// Zero-sized marker type that carries the [`VTable`] implementation for
/// [`MaskedArray`].
#[derive(Debug)]
pub struct MaskedVTable;
44
45impl MaskedVTable {
46    pub const ID: ArrayId = ArrayId::new_ref("vortex.masked");
47}
48
49impl VTable for MaskedVTable {
50    type Array = MaskedArray;
51
52    type Metadata = EmptyMetadata;
53    type OperationsVTable = Self;
54    type ValidityVTable = ValidityVTableFromValidityHelper;
55
56    fn id(_array: &Self::Array) -> ArrayId {
57        Self::ID
58    }
59
60    fn len(array: &MaskedArray) -> usize {
61        array.child.len()
62    }
63
64    fn dtype(array: &MaskedArray) -> &DType {
65        &array.dtype
66    }
67
68    fn stats(array: &MaskedArray) -> StatsSetRef<'_> {
69        array.stats.to_ref(array.as_ref())
70    }
71
72    fn array_hash<H: std::hash::Hasher>(array: &MaskedArray, state: &mut H, precision: Precision) {
73        array.child.array_hash(state, precision);
74        array.validity.array_hash(state, precision);
75        array.dtype.hash(state);
76    }
77
78    fn array_eq(array: &MaskedArray, other: &MaskedArray, precision: Precision) -> bool {
79        array.child.array_eq(&other.child, precision)
80            && array.validity.array_eq(&other.validity, precision)
81            && array.dtype == other.dtype
82    }
83
84    fn nbuffers(_array: &Self::Array) -> usize {
85        0
86    }
87
88    fn buffer(_array: &Self::Array, _idx: usize) -> BufferHandle {
89        vortex_panic!("MaskedArray has no buffers")
90    }
91
92    fn buffer_name(_array: &Self::Array, _idx: usize) -> Option<String> {
93        None
94    }
95
96    fn nchildren(array: &Self::Array) -> usize {
97        1 + validity_nchildren(&array.validity)
98    }
99
100    fn child(array: &Self::Array, idx: usize) -> ArrayRef {
101        match idx {
102            0 => array.child.clone(),
103            1 => validity_to_child(&array.validity, array.child.len())
104                .vortex_expect("MaskedArray validity child out of bounds"),
105            _ => vortex_panic!("MaskedArray child index {idx} out of bounds"),
106        }
107    }
108
109    fn child_name(_array: &Self::Array, idx: usize) -> String {
110        match idx {
111            0 => "child".to_string(),
112            1 => "validity".to_string(),
113            _ => vortex_panic!("MaskedArray child_name index {idx} out of bounds"),
114        }
115    }
116
117    fn metadata(_array: &MaskedArray) -> VortexResult<Self::Metadata> {
118        Ok(EmptyMetadata)
119    }
120
121    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
122        Ok(Some(vec![]))
123    }
124
125    fn deserialize(
126        _bytes: &[u8],
127        _dtype: &DType,
128        _len: usize,
129        _buffers: &[BufferHandle],
130        _session: &VortexSession,
131    ) -> VortexResult<Self::Metadata> {
132        Ok(EmptyMetadata)
133    }
134
135    fn build(
136        dtype: &DType,
137        len: usize,
138        _metadata: &Self::Metadata,
139        buffers: &[BufferHandle],
140        children: &dyn ArrayChildren,
141    ) -> VortexResult<MaskedArray> {
142        if !buffers.is_empty() {
143            vortex_bail!("Expected 0 buffer, got {}", buffers.len());
144        }
145
146        let child = children.get(0, &dtype.as_nonnullable(), len)?;
147
148        let validity = if children.len() == 1 {
149            Validity::from(dtype.nullability())
150        } else if children.len() == 2 {
151            let validity = children.get(1, &Validity::DTYPE, len)?;
152            Validity::Array(validity)
153        } else {
154            vortex_bail!(
155                "`MaskedArray::build` expects 1 or 2 children, got {}",
156                children.len()
157            );
158        };
159
160        MaskedArray::try_new(child, validity)
161    }
162
163    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
164        let validity_mask = array.validity_mask()?;
165
166        // Fast path: all masked means result is all nulls.
167        if validity_mask.all_false() {
168            return Ok(
169                ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.len())
170                    .into_array(),
171            );
172        }
173
174        // NB: We intentionally do NOT have a fast path for `validity_mask.all_true()`.
175        // `MaskedArray`'s dtype is always `Nullable`, but the child has `NonNullable` `DType` (by
176        // invariant). Simply returning the child's canonical would cause a dtype mismatch.
177        // While we could manually convert the dtype, `mask_validity_canonical` is already O(1) for
178        // `AllTrue` masks (no data copying), so there's no benefit.
179
180        let child = array.child().clone().execute::<Canonical>(ctx)?;
181        Ok(mask_validity_canonical(child, &validity_mask, ctx)?.into_array())
182    }
183
184    fn reduce_parent(
185        array: &Self::Array,
186        parent: &ArrayRef,
187        child_idx: usize,
188    ) -> VortexResult<Option<ArrayRef>> {
189        PARENT_RULES.evaluate(array, parent, child_idx)
190    }
191
192    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
193        vortex_ensure!(
194            children.len() == 1 || children.len() == 2,
195            "MaskedArray expects 1 or 2 children, got {}",
196            children.len()
197        );
198
199        let mut iter = children.into_iter();
200        let child = iter
201            .next()
202            .vortex_expect("children length already validated");
203        let validity = if let Some(validity_array) = iter.next() {
204            Validity::Array(validity_array)
205        } else {
206            Validity::from(array.dtype.nullability())
207        };
208
209        let new_array = MaskedArray::try_new(child, validity)?;
210        *array = new_array;
211        Ok(())
212    }
213}
214
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBufferMut;
    use vortex_error::VortexError;

    use crate::ArrayContext;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::MaskedArray;
    use crate::arrays::MaskedVTable;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::Nullability;
    use crate::serde::ArrayParts;
    use crate::serde::SerializeOptions;
    use crate::validity::Validity;

    /// Serializes a masked array, decodes it again, and checks that the decoded array still
    /// uses the masked encoding and displays the same values.
    #[rstest]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            Validity::from_iter([true, true, false, true, false])
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter(0..100).into_array(),
            Validity::from_iter((0..100).map(|i| i % 3 != 0))
        ).unwrap()
    )]
    fn test_serde_roundtrip(#[case] array: MaskedArray) {
        let expected_dtype = array.dtype().clone();
        let expected_len = array.len();
        let ctx = ArrayContext::empty();

        let buffers = array
            .to_array()
            .serialize(&ctx, &SerializeOptions::default())
            .unwrap();

        // Glue the serialized pieces into one contiguous buffer before decoding.
        let mut bytes = ByteBufferMut::empty();
        for chunk in buffers {
            bytes.extend_from_slice(chunk.as_ref());
        }
        let bytes = bytes.freeze();

        let decoded = ArrayParts::try_from(bytes)
            .unwrap()
            .decode(&expected_dtype, expected_len, &ctx, &LEGACY_SESSION)
            .unwrap();

        assert!(decoded.is::<MaskedVTable>());
        assert_eq!(
            array.as_ref().display_values().to_string(),
            decoded.display_values().to_string()
        );
    }

    /// Regression test for issue #5989: execute_fast_path returns child with wrong dtype.
    ///
    /// With an all-true validity mask, handing back the child's canonical form unchanged
    /// would yield a NonNullable result even though MaskedArray always has a Nullable dtype.
    #[test]
    fn test_execute_with_all_valid_preserves_nullable_dtype() -> Result<(), VortexError> {
        // The wrapped values are NonNullable on their own...
        let values = PrimitiveArray::from_iter([1i32, 2, 3]).into_array();
        assert_eq!(values.dtype().nullability(), Nullability::NonNullable);

        // ...while the MaskedArray built around them with AllValid validity is Nullable.
        let masked = MaskedArray::try_new(values, Validity::AllValid)?;
        assert_eq!(masked.dtype().nullability(), Nullability::Nullable);

        // Executing must yield a Canonical that keeps the Nullable dtype.
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let canonical: Canonical = masked.into_array().execute(&mut ctx)?;

        assert_eq!(
            canonical.as_ref().dtype().nullability(),
            Nullability::Nullable,
            "MaskedArray execute should produce Nullable dtype"
        );

        Ok(())
    }
}