1use fsst::{Compressor, Symbol};
2use vortex_array::arrays::VarBinArray;
3use vortex_array::serde::ArrayParts;
4use vortex_array::vtable::EncodingVTable;
5use vortex_array::{
6 Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayContext, ArrayExt, ArrayRef,
7 ArrayVisitorImpl, Canonical, DeserializeMetadata, Encoding, EncodingId, ProstMetadata,
8};
9use vortex_buffer::Buffer;
10use vortex_dtype::{DType, Nullability, PType};
11use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
12
13use crate::{FSSTArray, FSSTEncoding, fsst_compress, fsst_train_compressor};
14
/// Serialized metadata for an FSST-encoded array.
///
/// Written by `FSSTArray::_metadata` and read back when decoding the array's parts.
#[derive(Clone, prost::Message)]
pub struct FSSTMetadata {
    /// Primitive type of the `uncompressed_lengths` child, stored as the raw prost enum
    /// value of [`PType`]. The field tag is part of the serialized format and must not change.
    #[prost(enumeration = "PType", tag = "1")]
    uncompressed_lengths_ptype: i32,
}
20
21impl EncodingVTable for FSSTEncoding {
22 fn id(&self) -> EncodingId {
23 EncodingId::new_ref("vortex.fsst")
24 }
25
26 fn decode(
27 &self,
28 parts: &ArrayParts,
29 ctx: &ArrayContext,
30 dtype: DType,
31 len: usize,
32 ) -> VortexResult<ArrayRef> {
33 let metadata = ProstMetadata::<FSSTMetadata>::deserialize(parts.metadata())?;
34
35 if parts.nbuffers() != 2 {
36 vortex_bail!(InvalidArgument: "Expected 2 buffers, got {}", parts.nbuffers());
37 }
38 let symbols = Buffer::<Symbol>::from_byte_buffer(parts.buffer(0)?);
39 let symbol_lengths = Buffer::<u8>::from_byte_buffer(parts.buffer(1)?);
40
41 if parts.nchildren() != 2 {
42 vortex_bail!(InvalidArgument: "Expected 2 children, got {}", parts.nchildren());
43 }
44 let codes = parts
45 .child(0)
46 .decode(ctx, DType::Binary(dtype.nullability()), len)?
47 .as_opt::<VarBinArray>()
48 .ok_or_else(|| {
49 vortex_err!(
50 "Expected VarBinArray for codes, got {:?}",
51 ctx.lookup_encoding(parts.child(0).encoding_id())
52 )
53 })?
54 .clone();
55 let uncompressed_lengths = parts.child(1).decode(
56 ctx,
57 DType::Primitive(
58 metadata.uncompressed_lengths_ptype(),
59 Nullability::NonNullable,
60 ),
61 len,
62 )?;
63
64 Ok(
65 FSSTArray::try_new(dtype, symbols, symbol_lengths, codes, uncompressed_lengths)?
66 .into_array(),
67 )
68 }
69
70 fn encode(
71 &self,
72 input: &Canonical,
73 like: Option<&dyn Array>,
74 ) -> VortexResult<Option<ArrayRef>> {
75 let like = like
76 .map(|like| {
77 like.as_opt::<<Self as Encoding>::Array>().ok_or_else(|| {
78 vortex_err!(
79 "Expected {} encoded array but got {}",
80 self.id(),
81 like.encoding()
82 )
83 })
84 })
85 .transpose()?;
86
87 let array = input.clone().into_varbinview()?;
88
89 let compressor = match like {
90 Some(like) => Compressor::rebuild_from(like.symbols(), like.symbol_lengths()),
91 None => fsst_train_compressor(&array)?,
92 };
93
94 let fsst = fsst_compress(&array, &compressor)?;
95
96 Ok(Some(fsst.into_array()))
97 }
98}
99
100impl ArrayVisitorImpl<ProstMetadata<FSSTMetadata>> for FSSTArray {
101 fn _visit_buffers(&self, visitor: &mut dyn ArrayBufferVisitor) {
102 visitor.visit_buffer(&self.symbols().clone().into_byte_buffer());
103 visitor.visit_buffer(&self.symbol_lengths().clone().into_byte_buffer());
104 }
105
106 fn _visit_children(&self, visitor: &mut dyn ArrayChildVisitor) {
107 visitor.visit_child("codes", self.codes());
108 visitor.visit_child("uncompressed_lengths", self.uncompressed_lengths());
109 }
110
111 fn _metadata(&self) -> ProstMetadata<FSSTMetadata> {
112 ProstMetadata(FSSTMetadata {
113 uncompressed_lengths_ptype: PType::try_from(self.uncompressed_lengths().dtype())
114 .vortex_expect("Must be a valid PType")
115 as i32,
116 })
117 }
118}
119
#[cfg(test)]
mod test {
    use vortex_array::ProstMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::serde::FSSTMetadata;

    /// Runs `check_metadata` on FSST metadata whose lengths type is `U64`, under the name
    /// `fsst.metadata` (presumably compared against a stored fixture; see the test harness).
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_fsst_metadata() {
        let metadata = FSSTMetadata {
            uncompressed_lengths_ptype: PType::U64 as i32,
        };
        check_metadata("fsst.metadata", ProstMetadata(metadata));
    }
}