vortex_array/arrays/masked/vtable/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod array;
5mod canonical;
6mod operations;
7mod validity;
8
9use vortex_buffer::BufferHandle;
10use vortex_dtype::DType;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_vector::Vector;
14use vortex_vector::VectorOps;
15
16use crate::ArrayBufferVisitor;
17use crate::ArrayChildVisitor;
18use crate::EmptyMetadata;
19use crate::arrays::masked::MaskedArray;
20use crate::execution::ExecutionCtx;
21use crate::serde::ArrayChildren;
22use crate::validity::Validity;
23use crate::vtable;
24use crate::vtable::ArrayId;
25use crate::vtable::ArrayVTable;
26use crate::vtable::ArrayVTableExt;
27use crate::vtable::NotSupported;
28use crate::vtable::VTable;
29use crate::vtable::ValidityVTableFromValidityHelper;
30use crate::vtable::VisitorVTable;
31
// Invokes the crate-level `vtable!` macro (imported from `crate::vtable`) for the `Masked`
// encoding. NOTE(review): presumably this generates the shared glue connecting `MaskedArray`
// to `MaskedVTable`; the expansion is not visible here — confirm against the macro definition.
vtable!(Masked);

/// Vtable for the masked encoding.
///
/// Implements [`VTable`] (with [`MaskedArray`] as its array type and encoding id
/// `"vortex.masked"`) and [`VisitorVTable`] for masked arrays.
#[derive(Debug)]
pub struct MaskedVTable;
36
37impl VisitorVTable<MaskedVTable> for MaskedVTable {
38    fn visit_buffers(_array: &MaskedArray, _visitor: &mut dyn ArrayBufferVisitor) {}
39
40    fn visit_children(array: &MaskedArray, visitor: &mut dyn ArrayChildVisitor) {
41        visitor.visit_child("child", array.child.as_ref());
42        visitor.visit_validity(&array.validity, array.child.len());
43    }
44}
45
46impl VTable for MaskedVTable {
47    type Array = MaskedArray;
48
49    type Metadata = EmptyMetadata;
50
51    type ArrayVTable = Self;
52    type CanonicalVTable = Self;
53    type OperationsVTable = Self;
54    type ValidityVTable = ValidityVTableFromValidityHelper;
55    type VisitorVTable = Self;
56    type ComputeVTable = NotSupported;
57    type EncodeVTable = NotSupported;
58
59    fn id(&self) -> ArrayId {
60        ArrayId::new_ref("vortex.masked")
61    }
62
63    fn encoding(_array: &Self::Array) -> ArrayVTable {
64        MaskedVTable.as_vtable()
65    }
66
67    fn metadata(_array: &MaskedArray) -> VortexResult<Self::Metadata> {
68        Ok(EmptyMetadata)
69    }
70
71    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
72        Ok(Some(vec![]))
73    }
74
75    fn deserialize(_buffer: &[u8]) -> VortexResult<Self::Metadata> {
76        Ok(EmptyMetadata)
77    }
78
79    fn build(
80        &self,
81        dtype: &DType,
82        len: usize,
83        _metadata: &Self::Metadata,
84        buffers: &[BufferHandle],
85        children: &dyn ArrayChildren,
86    ) -> VortexResult<MaskedArray> {
87        if !buffers.is_empty() {
88            vortex_bail!("Expected 0 buffer, got {}", buffers.len());
89        }
90
91        let child = children.get(0, &dtype.as_nonnullable(), len)?;
92
93        let validity = if children.len() == 1 {
94            Validity::from(dtype.nullability())
95        } else if children.len() == 2 {
96            let validity = children.get(1, &Validity::DTYPE, len)?;
97            Validity::Array(validity)
98        } else {
99            vortex_bail!(
100                "`MaskedArray::build` expects 1 or 2 children, got {}",
101                children.len()
102            );
103        };
104
105        MaskedArray::try_new(child, validity)
106    }
107
108    fn batch_execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
109        let mut vector = array.child().batch_execute(ctx)?;
110        vector.mask_validity(&array.validity_mask());
111        Ok(vector)
112    }
113}
114
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBufferMut;

    use crate::ArrayContext;
    use crate::IntoArray;
    use crate::arrays::MaskedArray;
    use crate::arrays::MaskedVTable;
    use crate::arrays::PrimitiveArray;
    use crate::serde::ArrayParts;
    use crate::serde::SerializeOptions;
    use crate::validity::Validity;
    use crate::vtable::ArrayVTableExt;

    /// Round-trip test for the masked encoding.
    ///
    /// Serializes a `MaskedArray`, concatenates the emitted buffers into one contiguous
    /// buffer, and decodes it again. Checks that both the encoding and the rendered values
    /// survive.
    #[rstest]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            Validity::from_iter([true, true, false, true, false])
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter(0..100).into_array(),
            Validity::from_iter((0..100).map(|i| i % 3 != 0))
        ).unwrap()
    )]
    fn test_serde_roundtrip(#[case] array: MaskedArray) {
        let ctx = ArrayContext::empty().with(MaskedVTable.as_vtable());
        let expected_dtype = array.dtype().clone();
        let expected_len = array.len();

        // Flatten every serialized chunk into a single frozen buffer.
        let flattened = array
            .to_array()
            .serialize(&ctx, &SerializeOptions::default())
            .unwrap()
            .into_iter()
            .fold(ByteBufferMut::empty(), |mut acc, chunk| {
                acc.extend_from_slice(chunk.as_ref());
                acc
            })
            .freeze();

        let roundtripped = ArrayParts::try_from(flattened)
            .unwrap()
            .decode(&ctx, &expected_dtype, expected_len)
            .unwrap();

        assert!(roundtripped.is::<MaskedVTable>());

        let original_values = array.as_ref().display_values().to_string();
        let decoded_values = roundtripped.display_values().to_string();
        assert_eq!(original_values, decoded_values);
    }
}
175}