vortex_dict/
serde.rs

1use vortex_array::serde::ArrayChildren;
2use vortex_array::vtable::{EncodeVTable, SerdeVTable, VisitorVTable};
3use vortex_array::{
4    Array, ArrayBufferVisitor, ArrayChildVisitor, Canonical, DeserializeMetadata, ProstMetadata,
5};
6use vortex_buffer::ByteBuffer;
7use vortex_dtype::{DType, PType};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
9
10use crate::builders::dict_encode;
11use crate::{DictArray, DictEncoding, DictVTable};
12
/// Serialized metadata for a dictionary-encoded array.
///
/// Written by `SerdeVTable::metadata` and read back in `SerdeVTable::build` to determine
/// the dtype of the codes child and the length of the values child.
#[derive(Clone, prost::Message)]
pub struct DictMetadata {
    /// Number of entries in the values (dictionary) child.
    #[prost(uint32, tag = "1")]
    values_len: u32,
    /// Primitive type of the codes child, stored as the raw `PType` enumeration value.
    #[prost(enumeration = "PType", tag = "2")]
    codes_ptype: i32,
}
20
21impl SerdeVTable<DictVTable> for DictVTable {
22    type Metadata = ProstMetadata<DictMetadata>;
23
24    fn metadata(array: &DictArray) -> VortexResult<Option<Self::Metadata>> {
25        Ok(Some(ProstMetadata(DictMetadata {
26            codes_ptype: PType::try_from(array.codes().dtype())
27                .vortex_expect("Must be a valid PType") as i32,
28            values_len: u32::try_from(array.values().len()).map_err(|_| {
29                vortex_err!("Diction values cannot exceed u32 in length for serialization")
30            })?,
31        })))
32    }
33
34    fn build(
35        _encoding: &DictEncoding,
36        dtype: &DType,
37        len: usize,
38        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
39        _buffers: &[ByteBuffer],
40        children: &dyn ArrayChildren,
41    ) -> VortexResult<DictArray> {
42        if children.len() != 2 {
43            vortex_bail!(
44                "Expected 2 children for dict encoding, found {}",
45                children.len()
46            )
47        }
48        let codes_dtype = DType::Primitive(metadata.codes_ptype(), dtype.nullability());
49        let codes = children.get(0, &codes_dtype, len)?;
50
51        let values = children.get(1, dtype, metadata.values_len as usize)?;
52
53        DictArray::try_new(codes, values)
54    }
55}
56
57impl EncodeVTable<DictVTable> for DictVTable {
58    fn encode(
59        _encoding: &DictEncoding,
60        canonical: &Canonical,
61        _like: Option<&DictArray>,
62    ) -> VortexResult<Option<DictArray>> {
63        Ok(Some(dict_encode(canonical.as_ref())?))
64    }
65}
66
67impl VisitorVTable<DictVTable> for DictVTable {
68    fn visit_buffers(_array: &DictArray, _visitor: &mut dyn ArrayBufferVisitor) {}
69
70    fn visit_children(array: &DictArray, visitor: &mut dyn ArrayChildVisitor) {
71        visitor.visit_child("codes", array.codes());
72        visitor.visit_child("values", array.values());
73    }
74}
75
#[cfg(test)]
mod test {
    use vortex_array::ProstMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::serde::DictMetadata;

    /// Checks the serialized form of [`DictMetadata`] under the name `dict.metadata`.
    ///
    /// It uses the widest unsigned code type and the largest representable values length.
    /// NOTE(review): presumably `check_metadata` compares against a stored golden file;
    /// confirm in `vortex_array::test_harness`.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        let widest = DictMetadata {
            values_len: u32::MAX,
            codes_ptype: PType::U64 as i32,
        };
        check_metadata("dict.metadata", ProstMetadata(widest));
    }
}