1use fsst::{Compressor, Symbol};
2use serde::{Deserialize, Serialize};
3use vortex_array::arrays::VarBinArray;
4use vortex_array::serde::ArrayParts;
5use vortex_array::vtable::EncodingVTable;
6use vortex_array::{
7 Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayContext, ArrayExt, ArrayRef,
8 ArrayVisitorImpl, Canonical, DeserializeMetadata, Encoding, EncodingId, SerdeMetadata,
9};
10use vortex_buffer::Buffer;
11use vortex_dtype::{DType, Nullability, PType};
12use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
13
14use crate::{FSSTArray, FSSTEncoding, fsst_compress, fsst_train_compressor};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct FSSTMetadata {
18 uncompressed_lengths_ptype: PType,
19}
20
21impl EncodingVTable for FSSTEncoding {
22 fn id(&self) -> EncodingId {
23 EncodingId::new_ref("vortex.fsst")
24 }
25
26 fn decode(
27 &self,
28 parts: &ArrayParts,
29 ctx: &ArrayContext,
30 dtype: DType,
31 len: usize,
32 ) -> VortexResult<ArrayRef> {
33 let metadata = SerdeMetadata::<FSSTMetadata>::deserialize(parts.metadata())?;
34
35 if parts.nbuffers() != 2 {
36 vortex_bail!(InvalidArgument: "Expected 2 buffers, got {}", parts.nbuffers());
37 }
38 let symbols = Buffer::<Symbol>::from_byte_buffer(parts.buffer(0)?);
39 let symbol_lengths = Buffer::<u8>::from_byte_buffer(parts.buffer(1)?);
40
41 if parts.nchildren() != 2 {
42 vortex_bail!(InvalidArgument: "Expected 2 children, got {}", parts.nchildren());
43 }
44 let codes = parts
45 .child(0)
46 .decode(ctx, DType::Binary(dtype.nullability()), len)?
47 .as_opt::<VarBinArray>()
48 .ok_or_else(|| {
49 vortex_err!(
50 "Expected VarBinArray for codes, got {:?}",
51 ctx.lookup_encoding(parts.child(0).encoding_id())
52 )
53 })?
54 .clone();
55 let uncompressed_lengths = parts.child(1).decode(
56 ctx,
57 DType::Primitive(
58 metadata.uncompressed_lengths_ptype,
59 Nullability::NonNullable,
60 ),
61 len,
62 )?;
63
64 Ok(
65 FSSTArray::try_new(dtype, symbols, symbol_lengths, codes, uncompressed_lengths)?
66 .into_array(),
67 )
68 }
69
70 fn encode(
71 &self,
72 input: &Canonical,
73 like: Option<&dyn Array>,
74 ) -> VortexResult<Option<ArrayRef>> {
75 let like = like
76 .map(|like| {
77 like.as_opt::<<Self as Encoding>::Array>().ok_or_else(|| {
78 vortex_err!(
79 "Expected {} encoded array but got {}",
80 self.id(),
81 like.encoding()
82 )
83 })
84 })
85 .transpose()?;
86
87 let array = input.clone().into_varbinview()?;
88
89 let compressor = match like {
90 Some(like) => Compressor::rebuild_from(like.symbols(), like.symbol_lengths()),
91 None => fsst_train_compressor(&array)?,
92 };
93
94 let fsst = fsst_compress(&array, &compressor)?;
95
96 Ok(Some(fsst.into_array()))
97 }
98}
99
100impl ArrayVisitorImpl<SerdeMetadata<FSSTMetadata>> for FSSTArray {
101 fn _visit_buffers(&self, visitor: &mut dyn ArrayBufferVisitor) {
102 visitor.visit_buffer(&self.symbols().clone().into_byte_buffer());
103 visitor.visit_buffer(&self.symbol_lengths().clone().into_byte_buffer());
104 }
105
106 fn _visit_children(&self, visitor: &mut dyn ArrayChildVisitor) {
107 visitor.visit_child("codes", self.codes());
108 visitor.visit_child("uncompressed_lengths", self.uncompressed_lengths());
109 }
110
111 fn _metadata(&self) -> SerdeMetadata<FSSTMetadata> {
112 SerdeMetadata(FSSTMetadata {
113 uncompressed_lengths_ptype: PType::try_from(self.uncompressed_lengths().dtype())
114 .vortex_expect("Must be a valid PType"),
115 })
116 }
117}
118
#[cfg(test)]
mod test {
    use vortex_array::SerdeMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::serde::FSSTMetadata;

    /// Passes a `u64`-lengths `FSSTMetadata` to `check_metadata` under the `fsst.metadata` name.
    /// NOTE(review): presumably this compares against a stored fixture; confirm in `test_harness`.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_fsst_metadata() {
        let metadata = FSSTMetadata {
            uncompressed_lengths_ptype: PType::U64,
        };
        check_metadata("fsst.metadata", SerdeMetadata(metadata));
    }
}