1use vortex_array::serde::ArrayChildren;
5use vortex_array::vtable::{EncodeVTable, SerdeVTable, VisitorVTable};
6use vortex_array::{
7 Array, ArrayBufferVisitor, ArrayChildVisitor, Canonical, DeserializeMetadata, ProstMetadata,
8};
9use vortex_buffer::ByteBuffer;
10use vortex_dtype::{DType, PType};
11use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
12
13use crate::builders::dict_encode;
14use crate::{DictArray, DictEncoding, DictVTable};
15
/// Serialized metadata for a dictionary-encoded array.
///
/// Records what is needed to rebuild the two children during deserialization:
/// the number of dictionary values and the primitive type of the codes child.
#[derive(Clone, prost::Message)]
pub struct DictMetadata {
    /// Number of entries in the `values` child. It is stored as a `u32`, so
    /// serialization fails for dictionaries with more than `u32::MAX` values.
    #[prost(uint32, tag = "1")]
    values_len: u32,
    /// Primitive type of the `codes` child, stored as the raw `PType` discriminant.
    #[prost(enumeration = "PType", tag = "2")]
    codes_ptype: i32,
}
23
24impl SerdeVTable<DictVTable> for DictVTable {
25 type Metadata = ProstMetadata<DictMetadata>;
26
27 fn metadata(array: &DictArray) -> VortexResult<Option<Self::Metadata>> {
28 Ok(Some(ProstMetadata(DictMetadata {
29 codes_ptype: PType::try_from(array.codes().dtype())
30 .vortex_expect("Must be a valid PType") as i32,
31 values_len: u32::try_from(array.values().len()).map_err(|_| {
32 vortex_err!("Diction values cannot exceed u32 in length for serialization")
33 })?,
34 })))
35 }
36
37 fn build(
38 _encoding: &DictEncoding,
39 dtype: &DType,
40 len: usize,
41 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
42 _buffers: &[ByteBuffer],
43 children: &dyn ArrayChildren,
44 ) -> VortexResult<DictArray> {
45 if children.len() != 2 {
46 vortex_bail!(
47 "Expected 2 children for dict encoding, found {}",
48 children.len()
49 )
50 }
51 let codes_dtype = DType::Primitive(metadata.codes_ptype(), dtype.nullability());
52 let codes = children.get(0, &codes_dtype, len)?;
53
54 let values = children.get(1, dtype, metadata.values_len as usize)?;
55
56 DictArray::try_new(codes, values)
57 }
58}
59
60impl EncodeVTable<DictVTable> for DictVTable {
61 fn encode(
62 _encoding: &DictEncoding,
63 canonical: &Canonical,
64 _like: Option<&DictArray>,
65 ) -> VortexResult<Option<DictArray>> {
66 Ok(Some(dict_encode(canonical.as_ref())?))
67 }
68}
69
70impl VisitorVTable<DictVTable> for DictVTable {
71 fn visit_buffers(_array: &DictArray, _visitor: &mut dyn ArrayBufferVisitor) {}
72
73 fn visit_children(array: &DictArray, visitor: &mut dyn ArrayChildVisitor) {
74 visitor.visit_child("codes", array.codes());
75 visitor.visit_child("values", array.values());
76 }
77}
78
#[cfg(test)]
mod test {
    use vortex_array::ProstMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::serde::DictMetadata;

    /// Pins the serialized form of `DictMetadata` against the `dict.metadata` snapshot,
    /// using the widest codes type and the largest representable values length.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        let metadata = DictMetadata {
            values_len: u32::MAX,
            codes_ptype: PType::U64 as i32,
        };
        check_metadata("dict.metadata", ProstMetadata(metadata));
    }
}