1use vortex_array::serde::ArrayParts;
2use vortex_array::vtable::EncodingVTable;
3use vortex_array::{
4 Array, ArrayChildVisitor, ArrayContext, ArrayRef, ArrayVisitorImpl, Canonical,
5 DeserializeMetadata, EncodingId, ProstMetadata,
6};
7use vortex_dtype::{DType, PType};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9
10use crate::builders::dict_encode;
11use crate::{DictArray, DictEncoding};
12
/// Protobuf-serialized metadata stored alongside a dictionary-encoded array.
///
/// It records what is needed to decode the two children of a `DictArray`:
/// the length of the values child and the primitive type of the codes child.
#[derive(Clone, prost::Message)]
pub struct DictMetadata {
 /// Number of entries in the values child; used as that child's length when decoding.
 #[prost(uint32, tag = "1")]
 values_len: u32,
 /// Primitive type of the codes child, stored as the `i32` value of a `PType`.
 #[prost(enumeration = "PType", tag = "2")]
 codes_ptype: i32,
}
20
21impl EncodingVTable for DictEncoding {
22 fn id(&self) -> EncodingId {
23 EncodingId::new_ref("vortex.dict")
24 }
25
26 fn decode(
27 &self,
28 parts: &ArrayParts,
29 ctx: &ArrayContext,
30 dtype: DType,
31 len: usize,
32 ) -> VortexResult<ArrayRef> {
33 if parts.nchildren() != 2 {
34 vortex_bail!(
35 "Expected 2 children for dict encoding, found {}",
36 parts.nchildren()
37 )
38 }
39 let metadata = ProstMetadata::<DictMetadata>::deserialize(parts.metadata())?;
40
41 let codes_dtype = DType::Primitive(metadata.codes_ptype(), dtype.nullability());
42 let codes = parts.child(0).decode(ctx, codes_dtype, len)?;
43
44 let values = parts
45 .child(1)
46 .decode(ctx, dtype, metadata.values_len as usize)?;
47
48 Ok(DictArray::try_new(codes, values)?.into_array())
49 }
50
51 fn encode(
52 &self,
53 input: &Canonical,
54 _like: Option<&dyn Array>,
55 ) -> VortexResult<Option<ArrayRef>> {
56 Ok(Some(dict_encode(input.as_ref())?.into_array()))
57 }
58}
59
60impl ArrayVisitorImpl<ProstMetadata<DictMetadata>> for DictArray {
61 fn _visit_children(&self, visitor: &mut dyn ArrayChildVisitor) {
62 visitor.visit_child("codes", self.codes());
63 visitor.visit_child("values", self.values());
64 }
65
66 fn _metadata(&self) -> ProstMetadata<DictMetadata> {
67 ProstMetadata(DictMetadata {
68 codes_ptype: PType::try_from(self.codes().dtype())
69 .vortex_expect("Must be a valid PType") as i32,
70 values_len: u32::try_from(self.values().len())
71 .vortex_expect("Values length cannot exceed u32"),
72 })
73 }
74}
75
#[cfg(test)]
mod test {
    use vortex_array::ProstMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::serde::DictMetadata;

    /// Passes a `DictMetadata` with the largest possible `values_len` and a
    /// `U64` codes ptype to `check_metadata` under the name `"dict.metadata"`.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        let metadata = ProstMetadata(DictMetadata {
            values_len: u32::MAX,
            codes_ptype: PType::U64 as i32,
        });

        check_metadata("dict.metadata", metadata);
    }
}