vortex_dict/
serde.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::serde::ArrayChildren;
5use vortex_array::vtable::{EncodeVTable, SerdeVTable, VisitorVTable};
6use vortex_array::{
7    Array, ArrayBufferVisitor, ArrayChildVisitor, Canonical, DeserializeMetadata, ProstMetadata,
8};
9use vortex_buffer::ByteBuffer;
10use vortex_dtype::{DType, PType};
11use vortex_error::{VortexResult, vortex_bail, vortex_err};
12
13use crate::builders::dict_encode;
14use crate::{DictArray, DictEncoding, DictVTable};
15
16#[derive(Clone, prost::Message)]
17pub struct DictMetadata {
18    #[prost(uint32, tag = "1")]
19    values_len: u32,
20    #[prost(enumeration = "PType", tag = "2")]
21    codes_ptype: i32,
22}
23
24impl SerdeVTable<DictVTable> for DictVTable {
25    type Metadata = ProstMetadata<DictMetadata>;
26
27    fn metadata(array: &DictArray) -> VortexResult<Option<Self::Metadata>> {
28        Ok(Some(ProstMetadata(DictMetadata {
29            codes_ptype: PType::try_from(array.codes().dtype())? as i32,
30            values_len: u32::try_from(array.values().len()).map_err(|_| {
31                vortex_err!(
32                    "Dictionary values size {} overflowed u32",
33                    array.values().len()
34                )
35            })?,
36        })))
37    }
38
39    fn build(
40        _encoding: &DictEncoding,
41        dtype: &DType,
42        len: usize,
43        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
44        _buffers: &[ByteBuffer],
45        children: &dyn ArrayChildren,
46    ) -> VortexResult<DictArray> {
47        if children.len() != 2 {
48            vortex_bail!(
49                "Expected 2 children for dict encoding, found {}",
50                children.len()
51            )
52        }
53        let codes_dtype = DType::Primitive(metadata.codes_ptype(), dtype.nullability());
54        let codes = children.get(0, &codes_dtype, len)?;
55        let values = children.get(1, dtype, metadata.values_len as usize)?;
56
57        DictArray::try_new(codes, values)
58    }
59}
60
61impl EncodeVTable<DictVTable> for DictVTable {
62    fn encode(
63        _encoding: &DictEncoding,
64        canonical: &Canonical,
65        _like: Option<&DictArray>,
66    ) -> VortexResult<Option<DictArray>> {
67        Ok(Some(dict_encode(canonical.as_ref())?))
68    }
69}
70
71impl VisitorVTable<DictVTable> for DictVTable {
72    fn visit_buffers(_array: &DictArray, _visitor: &mut dyn ArrayBufferVisitor) {}
73
74    fn visit_children(array: &DictArray, visitor: &mut dyn ArrayChildVisitor) {
75        visitor.visit_child("codes", array.codes());
76        visitor.visit_child("values", array.values());
77    }
78}
79
#[cfg(test)]
mod test {
    use vortex_array::ProstMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::serde::DictMetadata;

    /// Runs `check_metadata` against the `"dict.metadata"` fixture using
    /// maximal field values (`u32::MAX` values, `U64` codes).
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        let metadata = DictMetadata {
            values_len: u32::MAX,
            codes_ptype: PType::U64 as i32,
        };
        check_metadata("dict.metadata", ProstMetadata(metadata));
    }
}