vortex_dict/
serde.rs

1use vortex_array::serde::ArrayParts;
2use vortex_array::vtable::EncodingVTable;
3use vortex_array::{
4    Array, ArrayChildVisitor, ArrayContext, ArrayRef, ArrayVisitorImpl, Canonical,
5    DeserializeMetadata, EncodingId, ProstMetadata,
6};
7use vortex_dtype::{DType, PType};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9
10use crate::builders::dict_encode;
11use crate::{DictArray, DictEncoding};
12
/// Serialized metadata for a dictionary-encoded array.
///
/// Written by `DictArray::_metadata` and read back by `DictEncoding::decode`, which uses it to
/// decode the array's two children (codes, then values).
#[derive(Clone, prost::Message)]
pub struct DictMetadata {
    /// Number of entries in the values (dictionary) child; used as that child's decode length.
    #[prost(uint32, tag = "1")]
    values_len: u32,
    /// Primitive type of the codes child, stored as the raw `PType` enum discriminant.
    #[prost(enumeration = "PType", tag = "2")]
    codes_ptype: i32,
}
20
21impl EncodingVTable for DictEncoding {
22    fn id(&self) -> EncodingId {
23        EncodingId::new_ref("vortex.dict")
24    }
25
26    fn decode(
27        &self,
28        parts: &ArrayParts,
29        ctx: &ArrayContext,
30        dtype: DType,
31        len: usize,
32    ) -> VortexResult<ArrayRef> {
33        if parts.nchildren() != 2 {
34            vortex_bail!(
35                "Expected 2 children for dict encoding, found {}",
36                parts.nchildren()
37            )
38        }
39        let metadata = ProstMetadata::<DictMetadata>::deserialize(parts.metadata())?;
40
41        let codes_dtype = DType::Primitive(metadata.codes_ptype(), dtype.nullability());
42        let codes = parts.child(0).decode(ctx, codes_dtype, len)?;
43
44        let values = parts
45            .child(1)
46            .decode(ctx, dtype, metadata.values_len as usize)?;
47
48        Ok(DictArray::try_new(codes, values)?.into_array())
49    }
50
51    fn encode(
52        &self,
53        input: &Canonical,
54        _like: Option<&dyn Array>,
55    ) -> VortexResult<Option<ArrayRef>> {
56        Ok(Some(dict_encode(input.as_ref())?.into_array()))
57    }
58}
59
60impl ArrayVisitorImpl<ProstMetadata<DictMetadata>> for DictArray {
61    fn _visit_children(&self, visitor: &mut dyn ArrayChildVisitor) {
62        visitor.visit_child("codes", self.codes());
63        visitor.visit_child("values", self.values());
64    }
65
66    fn _metadata(&self) -> ProstMetadata<DictMetadata> {
67        ProstMetadata(DictMetadata {
68            codes_ptype: PType::try_from(self.codes().dtype())
69                .vortex_expect("Must be a valid PType") as i32,
70            values_len: u32::try_from(self.values().len())
71                .vortex_expect("Values length cannot exceed u32"),
72        })
73    }
74}
75
#[cfg(test)]
mod test {
    use vortex_array::ProstMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::serde::DictMetadata;

    /// Checks `DictMetadata` serialization under the `dict.metadata` key.
    ///
    /// Uses the widest codes type and the largest values length the field can hold.
    /// Presumably `check_metadata` compares the output against a stored fixture; confirm
    /// against `vortex_array::test_harness`.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        let extreme = DictMetadata {
            codes_ptype: PType::U64 as i32,
            values_len: u32::MAX,
        };
        check_metadata("dict.metadata", ProstMetadata(extreme));
    }
}