1use fsst::{Compressor, Symbol};
2use vortex_array::arrays::VarBinVTable;
3use vortex_array::serde::ArrayChildren;
4use vortex_array::vtable::{EncodeVTable, SerdeVTable, VisitorVTable};
5use vortex_array::{
6 Array, ArrayBufferVisitor, ArrayChildVisitor, Canonical, DeserializeMetadata, ProstMetadata,
7};
8use vortex_buffer::{Buffer, ByteBuffer};
9use vortex_dtype::{DType, Nullability, PType};
10use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
11
12use crate::{FSSTArray, FSSTEncoding, FSSTVTable, fsst_compress, fsst_train_compressor};
13
/// Protobuf-encoded metadata stored alongside a serialized FSST array.
#[derive(Clone, prost::Message)]
pub struct FSSTMetadata {
    /// Physical type of the `uncompressed_lengths` child, stored as the `i32`
    /// discriminant of [`PType`]. Deserialization uses it to request that child
    /// as a non-nullable primitive array of this type.
    #[prost(enumeration = "PType", tag = "1")]
    uncompressed_lengths_ptype: i32,
}
19
20impl SerdeVTable<FSSTVTable> for FSSTVTable {
21 type Metadata = ProstMetadata<FSSTMetadata>;
22
23 fn metadata(array: &FSSTArray) -> VortexResult<Option<Self::Metadata>> {
24 Ok(Some(ProstMetadata(FSSTMetadata {
25 uncompressed_lengths_ptype: PType::try_from(array.uncompressed_lengths().dtype())
26 .vortex_expect("Must be a valid PType")
27 as i32,
28 })))
29 }
30
31 fn build(
32 _encoding: &FSSTEncoding,
33 dtype: &DType,
34 len: usize,
35 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
36 buffers: &[ByteBuffer],
37 children: &dyn ArrayChildren,
38 ) -> VortexResult<FSSTArray> {
39 if buffers.len() != 2 {
40 vortex_bail!(InvalidArgument: "Expected 2 buffers, got {}", buffers.len());
41 }
42 let symbols = Buffer::<Symbol>::from_byte_buffer(buffers[0].clone());
43 let symbol_lengths = Buffer::<u8>::from_byte_buffer(buffers[1].clone());
44
45 if children.len() != 2 {
46 vortex_bail!(InvalidArgument: "Expected 2 children, got {}", children.len());
47 }
48 let codes = children.get(0, &DType::Binary(dtype.nullability()), len)?;
49 let codes = codes
50 .as_opt::<VarBinVTable>()
51 .ok_or_else(|| {
52 vortex_err!(
53 "Expected VarBinArray for codes, got {}",
54 codes.encoding_id()
55 )
56 })?
57 .clone();
58 let uncompressed_lengths = children.get(
59 1,
60 &DType::Primitive(
61 metadata.uncompressed_lengths_ptype(),
62 Nullability::NonNullable,
63 ),
64 len,
65 )?;
66
67 FSSTArray::try_new(
68 dtype.clone(),
69 symbols,
70 symbol_lengths,
71 codes,
72 uncompressed_lengths,
73 )
74 }
75}
76
77impl EncodeVTable<FSSTVTable> for FSSTVTable {
78 fn encode(
79 _encoding: &FSSTEncoding,
80 canonical: &Canonical,
81 like: Option<&FSSTArray>,
82 ) -> VortexResult<Option<FSSTArray>> {
83 let array = canonical.clone().into_varbinview()?;
84
85 let compressor = match like {
86 Some(like) => Compressor::rebuild_from(like.symbols(), like.symbol_lengths()),
87 None => fsst_train_compressor(array.as_ref())?,
88 };
89
90 Ok(Some(fsst_compress(array.as_ref(), &compressor)?))
91 }
92}
93
94impl VisitorVTable<FSSTVTable> for FSSTVTable {
95 fn visit_buffers(array: &FSSTArray, visitor: &mut dyn ArrayBufferVisitor) {
96 visitor.visit_buffer(&array.symbols().clone().into_byte_buffer());
97 visitor.visit_buffer(&array.symbol_lengths().clone().into_byte_buffer());
98 }
99
100 fn visit_children(array: &FSSTArray, visitor: &mut dyn ArrayChildVisitor) {
101 visitor.visit_child("codes", array.codes().as_ref());
102 visitor.visit_child("uncompressed_lengths", array.uncompressed_lengths());
103 }
104}
105
#[cfg(test)]
mod test {
    use vortex_array::ProstMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::serde::FSSTMetadata;

    /// Runs `check_metadata` under the `fsst.metadata` name.
    ///
    /// The metadata used records `U64` uncompressed lengths.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_fsst_metadata() {
        let metadata = FSSTMetadata {
            uncompressed_lengths_ptype: PType::U64 as i32,
        };
        check_metadata("fsst.metadata", ProstMetadata(metadata));
    }
}