// vortex_array/arrays/list/serde.rs

1use vortex_dtype::{DType, Nullability, PType};
2use vortex_error::{VortexExpect, VortexResult, vortex_bail};
3
4use super::{ListArray, ListEncoding};
5use crate::serde::ArrayParts;
6use crate::validity::Validity;
7use crate::vtable::EncodingVTable;
8use crate::{
9    Array, ArrayChildVisitor, ArrayContext, ArrayRef, ArrayVisitorImpl, DeserializeMetadata,
10    EncodingId, ProstMetadata,
11};
12
13impl EncodingVTable for ListEncoding {
14    fn id(&self) -> EncodingId {
15        EncodingId::new_ref("vortex.list")
16    }
17
18    fn decode(
19        &self,
20        parts: &ArrayParts,
21        ctx: &ArrayContext,
22        dtype: DType,
23        len: usize,
24    ) -> VortexResult<ArrayRef> {
25        let metadata = ProstMetadata::<ListMetadata>::deserialize(parts.metadata())?;
26
27        let validity = if parts.nchildren() == 2 {
28            Validity::from(dtype.nullability())
29        } else if parts.nchildren() == 3 {
30            let validity = parts.child(2).decode(ctx, Validity::DTYPE, len)?;
31            Validity::Array(validity)
32        } else {
33            vortex_bail!("Expected 2 or 3 children, got {}", parts.nchildren());
34        };
35
36        let DType::List(element_dtype, _) = &dtype else {
37            vortex_bail!("Expected List dtype, got {:?}", dtype);
38        };
39        let elements = parts.child(0).decode(
40            ctx,
41            element_dtype.as_ref().clone(),
42            usize::try_from(metadata.elements_len).vortex_expect("Too many elements"),
43        )?;
44
45        let offsets = parts.child(1).decode(
46            ctx,
47            DType::Primitive(metadata.offset_ptype(), Nullability::NonNullable),
48            len + 1,
49        )?;
50
51        Ok(ListArray::try_new(elements, offsets, validity)?.into_array())
52    }
53}
54
/// Serialized metadata for a [`ListArray`], encoded as a prost message.
#[derive(Clone, prost::Message)]
pub struct ListMetadata {
    /// Number of entries in the `elements` child; `decode` uses it as that child's length.
    #[prost(uint64, tag = "1")]
    elements_len: u64,
    /// Primitive type of the `offsets` child, stored as the prost `i32` value of [`PType`].
    // NOTE(review): `decode` reads this through prost's generated `offset_ptype()` getter,
    // which presumably maps unknown values to the default `PType` — confirm.
    #[prost(enumeration = "PType", tag = "2")]
    offset_ptype: i32,
}
62
63impl ArrayVisitorImpl<ProstMetadata<ListMetadata>> for ListArray {
64    fn _visit_children(&self, visitor: &mut dyn ArrayChildVisitor) {
65        visitor.visit_child("elements", self.elements());
66        visitor.visit_child("offsets", self.offsets());
67        visitor.visit_validity(self.validity(), self.len());
68    }
69
70    fn _metadata(&self) -> ProstMetadata<ListMetadata> {
71        ProstMetadata(ListMetadata {
72            elements_len: u64::try_from(self.elements().len())
73                .vortex_expect("More elements than u64"),
74            offset_ptype: PType::try_from(self.offsets().dtype())
75                .vortex_expect("Must be a valid PType") as i32,
76        })
77    }
78}