1use vortex_array::serde::ArrayParts;
2use vortex_array::vtable::SerdeVTable;
3use vortex_array::{
4 Array, ArrayChildVisitor, ArrayContext, ArrayRef, ArrayVisitorImpl, DeserializeMetadata,
5 RkyvMetadata,
6};
7use vortex_dtype::{DType, PType};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9
10use crate::{DictArray, DictEncoding};
11
12#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
13#[repr(C)]
14pub struct DictMetadata {
15 codes_ptype: PType,
16 values_len: u32,
17}
18
19impl ArrayVisitorImpl<RkyvMetadata<DictMetadata>> for DictArray {
20 fn _children(&self, visitor: &mut dyn ArrayChildVisitor) {
21 visitor.visit_child("codes", self.codes());
22 visitor.visit_child("values", self.values());
23 }
24
25 fn _metadata(&self) -> RkyvMetadata<DictMetadata> {
26 RkyvMetadata(DictMetadata {
27 codes_ptype: PType::try_from(self.codes().dtype())
28 .vortex_expect("Must be a valid PType"),
29 values_len: u32::try_from(self.values().len())
30 .vortex_expect("Values length cannot exceed u32"),
31 })
32 }
33}
34
35impl SerdeVTable<&DictArray> for DictEncoding {
36 fn decode(
37 &self,
38 parts: &ArrayParts,
39 ctx: &ArrayContext,
40 dtype: DType,
41 len: usize,
42 ) -> VortexResult<ArrayRef> {
43 if parts.nchildren() != 2 {
44 vortex_bail!(
45 "Expected 2 children for dict encoding, found {}",
46 parts.nchildren()
47 )
48 }
49 let metadata = RkyvMetadata::<DictMetadata>::deserialize(parts.metadata())?;
50
51 let codes_dtype = DType::Primitive(metadata.codes_ptype, dtype.nullability());
52 let codes = parts.child(0).decode(ctx, codes_dtype, len)?;
53
54 let values = parts
55 .child(1)
56 .decode(ctx, dtype, metadata.values_len as usize)?;
57
58 Ok(DictArray::try_new(codes, values)?.into_array())
59 }
60}
61
#[cfg(test)]
mod test {
    use vortex_array::RkyvMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::serde::DictMetadata;

    /// Passes a `DictMetadata` holding `PType::U64` and `u32::MAX` to
    /// `check_metadata` under the name `"dict.metadata"`. Ignored under Miri.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_dict_metadata() {
        let metadata = RkyvMetadata(DictMetadata {
            codes_ptype: PType::U64,
            values_len: u32::MAX,
        });

        check_metadata("dict.metadata", metadata);
    }
}