1use vortex_array::serde::ArrayChildren;
2use vortex_array::vtable::{EncodeVTable, SerdeVTable, VisitorVTable};
3use vortex_array::{
4 Array, ArrayBufferVisitor, ArrayChildVisitor, Canonical, DeserializeMetadata, ProstMetadata,
5};
6use vortex_buffer::ByteBuffer;
7use vortex_dtype::{DType, PType};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
9
10use crate::builders::dict_encode;
11use crate::{DictArray, DictEncoding, DictVTable};
12
13#[derive(Clone, prost::Message)]
14pub struct DictMetadata {
15 #[prost(uint32, tag = "1")]
16 values_len: u32,
17 #[prost(enumeration = "PType", tag = "2")]
18 codes_ptype: i32,
19}
20
21impl SerdeVTable<DictVTable> for DictVTable {
22 type Metadata = ProstMetadata<DictMetadata>;
23
24 fn metadata(array: &DictArray) -> VortexResult<Option<Self::Metadata>> {
25 Ok(Some(ProstMetadata(DictMetadata {
26 codes_ptype: PType::try_from(array.codes().dtype())
27 .vortex_expect("Must be a valid PType") as i32,
28 values_len: u32::try_from(array.values().len()).map_err(|_| {
29 vortex_err!("Diction values cannot exceed u32 in length for serialization")
30 })?,
31 })))
32 }
33
34 fn build(
35 _encoding: &DictEncoding,
36 dtype: &DType,
37 len: usize,
38 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
39 _buffers: &[ByteBuffer],
40 children: &dyn ArrayChildren,
41 ) -> VortexResult<DictArray> {
42 if children.len() != 2 {
43 vortex_bail!(
44 "Expected 2 children for dict encoding, found {}",
45 children.len()
46 )
47 }
48 let codes_dtype = DType::Primitive(metadata.codes_ptype(), dtype.nullability());
49 let codes = children.get(0, &codes_dtype, len)?;
50
51 let values = children.get(1, dtype, metadata.values_len as usize)?;
52
53 DictArray::try_new(codes, values)
54 }
55}
56
57impl EncodeVTable<DictVTable> for DictVTable {
58 fn encode(
59 _encoding: &DictEncoding,
60 canonical: &Canonical,
61 _like: Option<&DictArray>,
62 ) -> VortexResult<Option<DictArray>> {
63 Ok(Some(dict_encode(canonical.as_ref())?))
64 }
65}
66
67impl VisitorVTable<DictVTable> for DictVTable {
68 fn visit_buffers(_array: &DictArray, _visitor: &mut dyn ArrayBufferVisitor) {}
69
70 fn visit_children(array: &DictArray, visitor: &mut dyn ArrayChildVisitor) {
71 visitor.visit_child("codes", array.codes());
72 visitor.visit_child("values", array.values());
73 }
74}
75
#[cfg(test)]
mod test {
    use vortex_array::ProstMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::serde::DictMetadata;

    /// Passes a `DictMetadata` populated with `u32::MAX` values and `U64` codes
    /// to `check_metadata` under the name `dict.metadata`.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        let metadata = ProstMetadata(DictMetadata {
            values_len: u32::MAX,
            codes_ptype: PType::U64 as i32,
        });

        check_metadata("dict.metadata", metadata);
    }
}