vortex_array/arrays/masked/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod array;
5mod canonical;
6mod operations;
7mod validity;
8
9use vortex_dtype::DType;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_ensure;
14use vortex_vector::Vector;
15use vortex_vector::VectorOps;
16
17use crate::ArrayBufferVisitor;
18use crate::ArrayChildVisitor;
19use crate::ArrayRef;
20use crate::EmptyMetadata;
21use crate::VectorExecutor;
22use crate::arrays::masked::MaskedArray;
23use crate::buffer::BufferHandle;
24use crate::executor::ExecutionCtx;
25use crate::serde::ArrayChildren;
26use crate::validity::Validity;
27use crate::vtable;
28use crate::vtable::ArrayId;
29use crate::vtable::ArrayVTable;
30use crate::vtable::ArrayVTableExt;
31use crate::vtable::NotSupported;
32use crate::vtable::VTable;
33use crate::vtable::ValidityVTableFromValidityHelper;
34use crate::vtable::VisitorVTable;
35
// Invokes the crate's `vtable!` macro for the `Masked` encoding.
// NOTE(review): presumably this generates the glue that ties `MaskedArray` to
// `MaskedVTable` (conversions, trait plumbing) — confirm against the macro definition.
vtable!(Masked);
37
/// Marker type that carries the vtable implementations for [`MaskedArray`].
///
/// This unit struct holds no data. In this module it implements [`VTable`]
/// (registering the encoding under the id `vortex.masked`) and [`VisitorVTable`];
/// the `VTable` impl also names it as the array, canonical, and operations vtable.
#[derive(Debug)]
pub struct MaskedVTable;
40
41impl VisitorVTable<MaskedVTable> for MaskedVTable {
42    fn visit_buffers(_array: &MaskedArray, _visitor: &mut dyn ArrayBufferVisitor) {}
43
44    fn visit_children(array: &MaskedArray, visitor: &mut dyn ArrayChildVisitor) {
45        visitor.visit_child("child", &array.child);
46        visitor.visit_validity(&array.validity, array.child.len());
47    }
48}
49
50impl VTable for MaskedVTable {
51    type Array = MaskedArray;
52
53    type Metadata = EmptyMetadata;
54
55    type ArrayVTable = Self;
56    type CanonicalVTable = Self;
57    type OperationsVTable = Self;
58    type ValidityVTable = ValidityVTableFromValidityHelper;
59    type VisitorVTable = Self;
60    type ComputeVTable = NotSupported;
61    type EncodeVTable = NotSupported;
62
63    fn id(&self) -> ArrayId {
64        ArrayId::new_ref("vortex.masked")
65    }
66
67    fn encoding(_array: &Self::Array) -> ArrayVTable {
68        MaskedVTable.as_vtable()
69    }
70
71    fn metadata(_array: &MaskedArray) -> VortexResult<Self::Metadata> {
72        Ok(EmptyMetadata)
73    }
74
75    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
76        Ok(Some(vec![]))
77    }
78
79    fn deserialize(_buffer: &[u8]) -> VortexResult<Self::Metadata> {
80        Ok(EmptyMetadata)
81    }
82
83    fn build(
84        &self,
85        dtype: &DType,
86        len: usize,
87        _metadata: &Self::Metadata,
88        buffers: &[BufferHandle],
89        children: &dyn ArrayChildren,
90    ) -> VortexResult<MaskedArray> {
91        if !buffers.is_empty() {
92            vortex_bail!("Expected 0 buffer, got {}", buffers.len());
93        }
94
95        let child = children.get(0, &dtype.as_nonnullable(), len)?;
96
97        let validity = if children.len() == 1 {
98            Validity::from(dtype.nullability())
99        } else if children.len() == 2 {
100            let validity = children.get(1, &Validity::DTYPE, len)?;
101            Validity::Array(validity)
102        } else {
103            vortex_bail!(
104                "`MaskedArray::build` expects 1 or 2 children, got {}",
105                children.len()
106            );
107        };
108
109        MaskedArray::try_new(child, validity)
110    }
111
112    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
113        let mut child = array.child().execute(ctx)?;
114        let validity_mask = array.validity_mask();
115
116        child.mask_validity(&validity_mask);
117        Ok(child)
118    }
119
120    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
121        vortex_ensure!(
122            children.len() == 1 || children.len() == 2,
123            "MaskedArray expects 1 or 2 children, got {}",
124            children.len()
125        );
126
127        let mut iter = children.into_iter();
128        let child = iter
129            .next()
130            .vortex_expect("children length already validated");
131        let validity = if let Some(validity_array) = iter.next() {
132            Validity::Array(validity_array)
133        } else {
134            Validity::from(array.dtype.nullability())
135        };
136
137        let new_array = MaskedArray::try_new(child, validity)?;
138        *array = new_array;
139        Ok(())
140    }
141}
142
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBufferMut;

    use crate::ArrayContext;
    use crate::IntoArray;
    use crate::arrays::MaskedArray;
    use crate::arrays::MaskedVTable;
    use crate::arrays::PrimitiveArray;
    use crate::serde::ArrayParts;
    use crate::serde::SerializeOptions;
    use crate::validity::Validity;
    use crate::vtable::ArrayVTableExt;

    /// Serializes a masked array, decodes it again, and checks that the result
    /// is still masked-encoded and renders the same values.
    ///
    /// Cases: all-valid, a small mixed mask, and a 100-element array with
    /// every third element masked out.
    #[rstest]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            Validity::from_iter([true, true, false, true, false])
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter(0..100).into_array(),
            Validity::from_iter((0..100).map(|i| i % 3 != 0))
        ).unwrap()
    )]
    fn test_serde_roundtrip(#[case] array: MaskedArray) {
        let ctx = ArrayContext::empty().with(MaskedVTable.as_vtable());
        let expected_dtype = array.dtype().clone();
        let expected_len = array.len();

        let segments = array
            .to_array()
            .serialize(&ctx, &SerializeOptions::default())
            .unwrap();

        // Flatten every serialized segment into one contiguous buffer.
        let mut bytes = ByteBufferMut::empty();
        segments
            .into_iter()
            .for_each(|segment| bytes.extend_from_slice(segment.as_ref()));
        let bytes = bytes.freeze();

        let decoded = ArrayParts::try_from(bytes)
            .unwrap()
            .decode(&ctx, &expected_dtype, expected_len)
            .unwrap();

        assert!(decoded.is::<MaskedVTable>());

        let original_rendering = array.as_ref().display_values().to_string();
        let decoded_rendering = decoded.display_values().to_string();
        assert_eq!(original_rendering, decoded_rendering);
    }
}
203}