1use fsst::{Compressor, Symbol};
5use vortex_array::arrays::VarBinVTable;
6use vortex_array::serde::ArrayChildren;
7use vortex_array::vtable::{EncodeVTable, SerdeVTable, VisitorVTable};
8use vortex_array::{
9 Array, ArrayBufferVisitor, ArrayChildVisitor, Canonical, DeserializeMetadata, ProstMetadata,
10};
11use vortex_buffer::{Buffer, ByteBuffer};
12use vortex_dtype::{DType, Nullability, PType};
13use vortex_error::{VortexResult, vortex_bail, vortex_err};
14
15use crate::{FSSTArray, FSSTEncoding, FSSTVTable, fsst_compress, fsst_train_compressor};
16
17#[derive(Clone, prost::Message)]
18pub struct FSSTMetadata {
19 #[prost(enumeration = "PType", tag = "1")]
20 uncompressed_lengths_ptype: i32,
21}
22
23impl SerdeVTable<FSSTVTable> for FSSTVTable {
24 type Metadata = ProstMetadata<FSSTMetadata>;
25
26 fn metadata(array: &FSSTArray) -> VortexResult<Option<Self::Metadata>> {
27 Ok(Some(ProstMetadata(FSSTMetadata {
28 uncompressed_lengths_ptype: PType::try_from(array.uncompressed_lengths().dtype())?
29 as i32,
30 })))
31 }
32
33 fn build(
34 _encoding: &FSSTEncoding,
35 dtype: &DType,
36 len: usize,
37 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
38 buffers: &[ByteBuffer],
39 children: &dyn ArrayChildren,
40 ) -> VortexResult<FSSTArray> {
41 if buffers.len() != 2 {
42 vortex_bail!(InvalidArgument: "Expected 2 buffers, got {}", buffers.len());
43 }
44 let symbols = Buffer::<Symbol>::from_byte_buffer(buffers[0].clone());
45 let symbol_lengths = Buffer::<u8>::from_byte_buffer(buffers[1].clone());
46
47 if children.len() != 2 {
48 vortex_bail!(InvalidArgument: "Expected 2 children, got {}", children.len());
49 }
50 let codes = children.get(0, &DType::Binary(dtype.nullability()), len)?;
51 let codes = codes
52 .as_opt::<VarBinVTable>()
53 .ok_or_else(|| {
54 vortex_err!(
55 "Expected VarBinArray for codes, got {}",
56 codes.encoding_id()
57 )
58 })?
59 .clone();
60 let uncompressed_lengths = children.get(
61 1,
62 &DType::Primitive(
63 metadata.uncompressed_lengths_ptype(),
64 Nullability::NonNullable,
65 ),
66 len,
67 )?;
68
69 FSSTArray::try_new(
70 dtype.clone(),
71 symbols,
72 symbol_lengths,
73 codes,
74 uncompressed_lengths,
75 )
76 }
77}
78
79impl EncodeVTable<FSSTVTable> for FSSTVTable {
80 fn encode(
81 _encoding: &FSSTEncoding,
82 canonical: &Canonical,
83 like: Option<&FSSTArray>,
84 ) -> VortexResult<Option<FSSTArray>> {
85 let array = canonical.clone().into_varbinview();
86
87 let compressor = match like {
88 Some(like) => Compressor::rebuild_from(like.symbols(), like.symbol_lengths()),
89 None => fsst_train_compressor(&array),
90 };
91
92 Ok(Some(fsst_compress(array, &compressor)))
93 }
94}
95
96impl VisitorVTable<FSSTVTable> for FSSTVTable {
97 fn visit_buffers(array: &FSSTArray, visitor: &mut dyn ArrayBufferVisitor) {
98 visitor.visit_buffer(&array.symbols().clone().into_byte_buffer());
99 visitor.visit_buffer(&array.symbol_lengths().clone().into_byte_buffer());
100 }
101
102 fn visit_children(array: &FSSTArray, visitor: &mut dyn ArrayChildVisitor) {
103 visitor.visit_child("codes", array.codes().as_ref());
104 visitor.visit_child("uncompressed_lengths", array.uncompressed_lengths());
105 }
106}
107
#[cfg(test)]
mod test {
    use vortex_array::ProstMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::serde::FSSTMetadata;

    /// Runs `check_metadata` on an `FSSTMetadata` value under the name
    /// "fsst.metadata". Ignored under Miri.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_fsst_metadata() {
        let metadata = FSSTMetadata {
            uncompressed_lengths_ptype: PType::U64 as i32,
        };

        check_metadata("fsst.metadata", ProstMetadata(metadata));
    }
}