vortex_array/arrays/masked/vtable/serde.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::ByteBuffer;
5use vortex_dtype::DType;
6use vortex_error::{VortexResult, vortex_bail};
7
8use crate::arrays::{MaskedArray, MaskedEncoding, MaskedVTable};
9use crate::serde::ArrayChildren;
10use crate::validity::Validity;
11use crate::vtable::{SerdeVTable, VisitorVTable};
12use crate::{ArrayBufferVisitor, ArrayChildVisitor, EmptyMetadata};
13
14impl SerdeVTable<MaskedVTable> for MaskedVTable {
15    type Metadata = EmptyMetadata;
16
17    fn metadata(_array: &MaskedArray) -> VortexResult<Option<Self::Metadata>> {
18        Ok(Some(EmptyMetadata))
19    }
20
21    fn build(
22        _encoding: &MaskedEncoding,
23        dtype: &DType,
24        len: usize,
25        _metadata: &Self::Metadata,
26        buffers: &[ByteBuffer],
27        children: &dyn ArrayChildren,
28    ) -> VortexResult<MaskedArray> {
29        if !buffers.is_empty() {
30            vortex_bail!("Expected 0 buffer, got {}", buffers.len());
31        }
32
33        let child = children.get(0, &dtype.as_nonnullable(), len)?;
34
35        let validity = if children.len() == 1 {
36            Validity::from(dtype.nullability())
37        } else if children.len() == 2 {
38            let validity = children.get(1, &Validity::DTYPE, len)?;
39            Validity::Array(validity)
40        } else {
41            vortex_bail!(
42                "`MaskedArray::build` expects 1 or 2 children, got {}",
43                children.len()
44            );
45        };
46
47        MaskedArray::try_new(child, validity)
48    }
49}
50
51impl VisitorVTable<MaskedVTable> for MaskedVTable {
52    fn visit_buffers(_array: &MaskedArray, _visitor: &mut dyn ArrayBufferVisitor) {}
53
54    fn visit_children(array: &MaskedArray, visitor: &mut dyn ArrayChildVisitor) {
55        visitor.visit_child("child", array.child.as_ref());
56        visitor.visit_validity(&array.validity, array.child.len());
57    }
58}
59
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBufferMut;

    use super::*;
    use crate::arrays::{MaskedArray, PrimitiveArray};
    use crate::serde::{ArrayParts, SerializeOptions};
    use crate::{ArrayContext, EncodingRef, IntoArray};

    /// Round-trips a `MaskedArray` through serialization and decoding.
    ///
    /// The emitted buffers are stitched into one contiguous buffer, re-parsed as
    /// `ArrayParts`, and decoded. The test then checks that the result still uses the masked
    /// encoding and renders the same values as the input.
    #[rstest]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            Validity::from_iter([true, true, false, true, false])
        ).unwrap()
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter(0..100).into_array(),
            Validity::from_iter((0..100).map(|i| i % 3 != 0))
        ).unwrap()
    )]
    fn test_serde_roundtrip(#[case] array: MaskedArray) {
        let ctx = ArrayContext::empty().with(EncodingRef::new_ref(MaskedEncoding.as_ref()));
        let expected_dtype = array.dtype().clone();
        let expected_len = array.len();

        let chunks = array
            .to_array()
            .serialize(&ctx, &SerializeOptions::default())
            .unwrap();

        // Serialization yields several buffers; `ArrayParts` parses a single contiguous one.
        let mut joined = ByteBufferMut::empty();
        for chunk in chunks {
            joined.extend_from_slice(chunk.as_ref());
        }
        let joined = joined.freeze();

        let round_tripped = ArrayParts::try_from(joined)
            .unwrap()
            .decode(&ctx, &expected_dtype, expected_len)
            .unwrap();

        assert_eq!(round_tripped.encoding_id(), MaskedEncoding.id());

        let before = array.as_ref().display_values().to_string();
        let after = round_tripped.display_values().to_string();
        assert_eq!(before, after);
    }
}