vortex_array/arrays/dict/
serde.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::ByteBuffer;
5use vortex_dtype::{DType, Nullability, PType};
6use vortex_error::{VortexResult, vortex_bail, vortex_err};
7
8use super::{DictArray, DictEncoding, DictVTable};
9use crate::builders::dict::dict_encode;
10use crate::serde::ArrayChildren;
11use crate::vtable::{EncodeVTable, SerdeVTable, VisitorVTable};
12use crate::{
13    Array, ArrayBufferVisitor, ArrayChildVisitor, Canonical, DeserializeMetadata, ProstMetadata,
14};
15
/// Serialized metadata for a dictionary-encoded array, wrapped in `ProstMetadata` by the
/// serde vtable below.
///
/// Together with the array's dtype and length, these fields describe how to read back the
/// two children of a dictionary array: the codes and the values.
#[derive(Clone, prost::Message)]
pub struct DictMetadata {
    /// Number of entries in the values child; written from `array.values().len()` and used
    /// as the expected length of child 1 when rebuilding.
    #[prost(uint32, tag = "1")]
    values_len: u32,
    /// Primitive type of the codes child, stored as the raw `i32` discriminant of the
    /// [`PType`] prost enumeration and read back through the generated `codes_ptype()`
    /// accessor.
    ///
    /// NOTE(review): the generated accessor presumably maps unknown discriminants to a
    /// default variant instead of failing — confirm against the prost version in use.
    #[prost(enumeration = "PType", tag = "2")]
    codes_ptype: i32,
    /// Whether the codes child is nullable.
    ///
    /// Nullable codes are optional since they were added after stabilisation: metadata
    /// written before this field existed decodes as `None`, and readers then fall back to
    /// the nullability of the array dtype.
    #[prost(optional, bool, tag = "3")]
    is_nullable_codes: Option<bool>,
}
26
27impl SerdeVTable<DictVTable> for DictVTable {
28    type Metadata = ProstMetadata<DictMetadata>;
29
30    fn metadata(array: &DictArray) -> VortexResult<Option<Self::Metadata>> {
31        Ok(Some(ProstMetadata(DictMetadata {
32            codes_ptype: PType::try_from(array.codes().dtype())? as i32,
33            values_len: u32::try_from(array.values().len()).map_err(|_| {
34                vortex_err!(
35                    "Dictionary values size {} overflowed u32",
36                    array.values().len()
37                )
38            })?,
39            is_nullable_codes: Some(array.codes().dtype().is_nullable()),
40        })))
41    }
42
43    fn build(
44        _encoding: &DictEncoding,
45        dtype: &DType,
46        len: usize,
47        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
48        _buffers: &[ByteBuffer],
49        children: &dyn ArrayChildren,
50    ) -> VortexResult<DictArray> {
51        if children.len() != 2 {
52            vortex_bail!(
53                "Expected 2 children for dict encoding, found {}",
54                children.len()
55            )
56        }
57        let codes_nullable = metadata
58            .is_nullable_codes
59            .map(Nullability::from)
60            // If no `is_nullable_codes` metadata use the nullability of the values
61            // (and whole array) as before.
62            .unwrap_or_else(|| dtype.nullability());
63        let codes_dtype = DType::Primitive(metadata.codes_ptype(), codes_nullable);
64        let codes = children.get(0, &codes_dtype, len)?;
65        let values = children.get(1, dtype, metadata.values_len as usize)?;
66
67        DictArray::try_new(codes, values)
68    }
69}
70
71impl EncodeVTable<DictVTable> for DictVTable {
72    fn encode(
73        _encoding: &DictEncoding,
74        canonical: &Canonical,
75        _like: Option<&DictArray>,
76    ) -> VortexResult<Option<DictArray>> {
77        Ok(Some(dict_encode(canonical.as_ref())?))
78    }
79}
80
81impl VisitorVTable<DictVTable> for DictVTable {
82    fn visit_buffers(_array: &DictArray, _visitor: &mut dyn ArrayBufferVisitor) {}
83
84    fn visit_children(array: &DictArray, visitor: &mut dyn ArrayChildVisitor) {
85        visitor.visit_child("codes", array.codes());
86        visitor.visit_child("values", array.values());
87    }
88}
89
#[cfg(test)]
mod test {
    use vortex_dtype::PType;

    use super::DictMetadata;
    use crate::ProstMetadata;
    use crate::test_harness::check_metadata;

    /// Pins the serialized form of `DictMetadata` against the `dict.metadata` fixture,
    /// using the widest code type, the largest values length, and absent codes nullability
    /// (the pre-`is_nullable_codes` layout).
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        let metadata = DictMetadata {
            values_len: u32::MAX,
            codes_ptype: PType::U64 as i32,
            is_nullable_codes: None,
        };
        check_metadata("dict.metadata", ProstMetadata(metadata));
    }
}