1use vortex_array::serde::ArrayParts;
2use vortex_array::vtable::EncodingVTable;
3use vortex_array::{
4 Array, ArrayChildVisitor, ArrayContext, ArrayRef, ArrayVisitorImpl, Canonical,
5 DeserializeMetadata, EncodingId, RkyvMetadata,
6};
7use vortex_dtype::{DType, PType};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9
10use crate::builders::dict_encode;
11use crate::{DictArray, DictEncoding};
12
13#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
14#[repr(C)]
15pub struct DictMetadata {
16 codes_ptype: PType,
17 values_len: u32,
18}
19
20impl EncodingVTable for DictEncoding {
21 fn id(&self) -> EncodingId {
22 EncodingId::new_ref("vortex.dict")
23 }
24
25 fn decode(
26 &self,
27 parts: &ArrayParts,
28 ctx: &ArrayContext,
29 dtype: DType,
30 len: usize,
31 ) -> VortexResult<ArrayRef> {
32 if parts.nchildren() != 2 {
33 vortex_bail!(
34 "Expected 2 children for dict encoding, found {}",
35 parts.nchildren()
36 )
37 }
38 let metadata = RkyvMetadata::<DictMetadata>::deserialize(parts.metadata())?;
39
40 let codes_dtype = DType::Primitive(metadata.codes_ptype, dtype.nullability());
41 let codes = parts.child(0).decode(ctx, codes_dtype, len)?;
42
43 let values = parts
44 .child(1)
45 .decode(ctx, dtype, metadata.values_len as usize)?;
46
47 Ok(DictArray::try_new(codes, values)?.into_array())
48 }
49
50 fn encode(
51 &self,
52 input: &Canonical,
53 _like: Option<&dyn Array>,
54 ) -> VortexResult<Option<ArrayRef>> {
55 Ok(Some(dict_encode(input.as_ref())?.into_array()))
56 }
57}
58
59impl ArrayVisitorImpl<RkyvMetadata<DictMetadata>> for DictArray {
60 fn _visit_children(&self, visitor: &mut dyn ArrayChildVisitor) {
61 visitor.visit_child("codes", self.codes());
62 visitor.visit_child("values", self.values());
63 }
64
65 fn _metadata(&self) -> RkyvMetadata<DictMetadata> {
66 RkyvMetadata(DictMetadata {
67 codes_ptype: PType::try_from(self.codes().dtype())
68 .vortex_expect("Must be a valid PType"),
69 values_len: u32::try_from(self.values().len())
70 .vortex_expect("Values length cannot exceed u32"),
71 })
72 }
73}
74
#[cfg(test)]
mod test {
    use vortex_array::RkyvMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::serde::DictMetadata;

    /// Passes a `DictMetadata` fixture to the `check_metadata` harness under
    /// the name `"dict.metadata"`. The test is ignored when running under Miri.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        // Widest code type and largest representable values length.
        let fixture = RkyvMetadata(DictMetadata {
            codes_ptype: PType::U64,
            values_len: u32::MAX,
        });

        check_metadata("dict.metadata", fixture);
    }
}