vortex_dict/
serde.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::serde::ArrayChildren;
5use vortex_array::vtable::{EncodeVTable, SerdeVTable, VisitorVTable};
6use vortex_array::{
7    Array, ArrayBufferVisitor, ArrayChildVisitor, Canonical, DeserializeMetadata, ProstMetadata,
8};
9use vortex_buffer::ByteBuffer;
10use vortex_dtype::{DType, PType};
11use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
12
13use crate::builders::dict_encode;
14use crate::{DictArray, DictEncoding, DictVTable};
15
16#[derive(Clone, prost::Message)]
17pub struct DictMetadata {
18    #[prost(uint32, tag = "1")]
19    values_len: u32,
20    #[prost(enumeration = "PType", tag = "2")]
21    codes_ptype: i32,
22}
23
24impl SerdeVTable<DictVTable> for DictVTable {
25    type Metadata = ProstMetadata<DictMetadata>;
26
27    fn metadata(array: &DictArray) -> VortexResult<Option<Self::Metadata>> {
28        Ok(Some(ProstMetadata(DictMetadata {
29            codes_ptype: PType::try_from(array.codes().dtype())
30                .vortex_expect("Must be a valid PType") as i32,
31            values_len: u32::try_from(array.values().len()).map_err(|_| {
32                vortex_err!("Diction values cannot exceed u32 in length for serialization")
33            })?,
34        })))
35    }
36
37    fn build(
38        _encoding: &DictEncoding,
39        dtype: &DType,
40        len: usize,
41        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
42        _buffers: &[ByteBuffer],
43        children: &dyn ArrayChildren,
44    ) -> VortexResult<DictArray> {
45        if children.len() != 2 {
46            vortex_bail!(
47                "Expected 2 children for dict encoding, found {}",
48                children.len()
49            )
50        }
51        let codes_dtype = DType::Primitive(metadata.codes_ptype(), dtype.nullability());
52        let codes = children.get(0, &codes_dtype, len)?;
53
54        let values = children.get(1, dtype, metadata.values_len as usize)?;
55
56        DictArray::try_new(codes, values)
57    }
58}
59
60impl EncodeVTable<DictVTable> for DictVTable {
61    fn encode(
62        _encoding: &DictEncoding,
63        canonical: &Canonical,
64        _like: Option<&DictArray>,
65    ) -> VortexResult<Option<DictArray>> {
66        Ok(Some(dict_encode(canonical.as_ref())?))
67    }
68}
69
70impl VisitorVTable<DictVTable> for DictVTable {
71    fn visit_buffers(_array: &DictArray, _visitor: &mut dyn ArrayBufferVisitor) {}
72
73    fn visit_children(array: &DictArray, visitor: &mut dyn ArrayChildVisitor) {
74        visitor.visit_child("codes", array.codes());
75        visitor.visit_child("values", array.values());
76    }
77}
78
#[cfg(test)]
mod test {
    use vortex_array::ProstMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::serde::DictMetadata;

    /// Runs `check_metadata` on a sample `DictMetadata` under the name `"dict.metadata"`.
    ///
    /// The sample uses the largest representable `values_len` and `U64` codes. The test is
    /// ignored when running under Miri.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        let sample = DictMetadata {
            values_len: u32::MAX,
            codes_ptype: PType::U64 as i32,
        };

        check_metadata("dict.metadata", ProstMetadata(sample));
    }
}