vortex_array/arrays/list/
serde.rs

1use vortex_dtype::{DType, Nullability, PType};
2use vortex_error::{VortexExpect, VortexResult, vortex_bail};
3
4use crate::arrays::{ListArray, ListEncoding};
5use crate::serde::ArrayParts;
6use crate::validity::Validity;
7use crate::vtable::SerdeVTable;
8use crate::{
9    Array, ArrayChildVisitor, ArrayContext, ArrayRef, ArrayVisitorImpl, DeserializeMetadata,
10    RkyvMetadata,
11};
12
/// Serialized metadata for a [`ListArray`].
///
/// Produced by `ListArray::_metadata` and read back by `ListEncoding::decode`. It holds the
/// two facts about the children that cannot be recovered from the list's own dtype and length.
///
/// NOTE(review): rkyv lays out archived fields in declaration order, so reordering or
/// retyping these fields would likely change the serialized format. Treat this as a format
/// change, not a refactor.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct ListMetadata {
    /// Length of the flattened `elements` child. It is used as the decode length for child 0.
    ///
    /// NOTE(review): rkyv archives `usize` at a feature-configured width (`ArchivedUsize`).
    /// Confirm that the enabled rkyv features can represent element counts above `u32::MAX`.
    elements_len: usize,
    /// Primitive type of the `offsets` child. It is used to rebuild the offsets dtype, which is
    /// always decoded as non-nullable.
    offset_ptype: PType,
}
18
19impl ArrayVisitorImpl<RkyvMetadata<ListMetadata>> for ListArray {
20    fn _children(&self, visitor: &mut dyn ArrayChildVisitor) {
21        visitor.visit_child("elements", self.elements());
22        visitor.visit_child("offsets", self.offsets());
23        visitor.visit_validity(self.validity(), self.len());
24    }
25
26    fn _metadata(&self) -> RkyvMetadata<ListMetadata> {
27        RkyvMetadata(ListMetadata {
28            elements_len: self.elements().len(),
29            offset_ptype: PType::try_from(self.offsets().dtype())
30                .vortex_expect("Must be a valid PType"),
31        })
32    }
33}
34
35impl SerdeVTable<&ListArray> for ListEncoding {
36    fn decode(
37        &self,
38        parts: &ArrayParts,
39        ctx: &ArrayContext,
40        dtype: DType,
41        len: usize,
42    ) -> VortexResult<ArrayRef> {
43        let metadata = RkyvMetadata::<ListMetadata>::deserialize(parts.metadata())?;
44
45        let validity = if parts.nchildren() == 2 {
46            Validity::from(dtype.nullability())
47        } else if parts.nchildren() == 3 {
48            let validity = parts.child(2).decode(ctx, Validity::DTYPE, len)?;
49            Validity::Array(validity)
50        } else {
51            vortex_bail!("Expected 2 or 3 children, got {}", parts.nchildren());
52        };
53
54        let DType::List(element_dtype, _) = &dtype else {
55            vortex_bail!("Expected List dtype, got {:?}", dtype);
56        };
57        let elements =
58            parts
59                .child(0)
60                .decode(ctx, element_dtype.as_ref().clone(), metadata.elements_len)?;
61
62        let offsets = parts.child(1).decode(
63            ctx,
64            DType::Primitive(metadata.offset_ptype, Nullability::NonNullable),
65            len + 1,
66        )?;
67
68        Ok(ListArray::try_new(elements, offsets, validity)?.into_array())
69    }
70}