1use std::fmt::Debug;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8use std::sync::LazyLock;
9
10use fsst::Compressor;
11use fsst::Decompressor;
12use fsst::Symbol;
13use vortex_array::Array;
14use vortex_array::ArrayBufferVisitor;
15use vortex_array::ArrayChildVisitor;
16use vortex_array::ArrayEq;
17use vortex_array::ArrayHash;
18use vortex_array::ArrayRef;
19use vortex_array::Canonical;
20use vortex_array::DeserializeMetadata;
21use vortex_array::ExecutionCtx;
22use vortex_array::Precision;
23use vortex_array::ProstMetadata;
24use vortex_array::SerializeMetadata;
25use vortex_array::arrays::VarBinArray;
26use vortex_array::arrays::VarBinVTable;
27use vortex_array::buffer::BufferHandle;
28use vortex_array::serde::ArrayChildren;
29use vortex_array::stats::ArrayStats;
30use vortex_array::stats::StatsSetRef;
31use vortex_array::vtable;
32use vortex_array::vtable::ArrayId;
33use vortex_array::vtable::ArrayVTable;
34use vortex_array::vtable::ArrayVTableExt;
35use vortex_array::vtable::BaseArrayVTable;
36use vortex_array::vtable::EncodeVTable;
37use vortex_array::vtable::NotSupported;
38use vortex_array::vtable::VTable;
39use vortex_array::vtable::ValidityChild;
40use vortex_array::vtable::ValidityVTableFromChild;
41use vortex_array::vtable::VisitorVTable;
42use vortex_buffer::Buffer;
43use vortex_dtype::DType;
44use vortex_dtype::Nullability;
45use vortex_dtype::PType;
46use vortex_error::VortexResult;
47use vortex_error::vortex_bail;
48use vortex_error::vortex_ensure;
49use vortex_error::vortex_err;
50use vortex_vector::Vector;
51
52use crate::fsst_compress;
53use crate::fsst_train_compressor;
54use crate::kernel::PARENT_KERNELS;
55
// Generates the shared vtable boilerplate for the `FSST` encoding.
// NOTE(review): presumably this wires `FSSTArray` to `FSSTVTable` (e.g. the `as_ref`/`to_array`
// plumbing used further down) — confirm against the `vortex_array::vtable!` definition.
vtable!(FSST);
57
/// Serialized metadata for an [`FSSTArray`].
///
/// The symbol table travels as buffers and the codes / uncompressed lengths travel as children;
/// the only extra information needed to rebuild the array is the primitive type of the
/// `uncompressed_lengths` child.
#[derive(Clone, prost::Message)]
pub struct FSSTMetadata {
    /// The [`PType`] of the `uncompressed_lengths` child, stored as its prost `i32`
    /// enumeration value. Decode it with [`FSSTMetadata::get_uncompressed_lengths_ptype`].
    #[prost(enumeration = "PType", tag = "1")]
    uncompressed_lengths_ptype: i32,
}
63
64impl FSSTMetadata {
65 pub fn get_uncompressed_lengths_ptype(&self) -> VortexResult<PType> {
66 PType::try_from(self.uncompressed_lengths_ptype)
67 .map_err(|_| vortex_err!("Invalid PType {}", self.uncompressed_lengths_ptype))
68 }
69}
70
71impl VTable for FSSTVTable {
72 type Array = FSSTArray;
73
74 type Metadata = ProstMetadata<FSSTMetadata>;
75
76 type ArrayVTable = Self;
77 type CanonicalVTable = Self;
78 type OperationsVTable = Self;
79 type ValidityVTable = ValidityVTableFromChild;
80 type VisitorVTable = Self;
81 type ComputeVTable = NotSupported;
82 type EncodeVTable = Self;
83
84 fn id(&self) -> ArrayId {
85 ArrayId::new_ref("vortex.fsst")
86 }
87
88 fn encoding(_array: &Self::Array) -> ArrayVTable {
89 FSSTVTable.as_vtable()
90 }
91
92 fn metadata(array: &FSSTArray) -> VortexResult<Self::Metadata> {
93 Ok(ProstMetadata(FSSTMetadata {
94 uncompressed_lengths_ptype: PType::try_from(array.uncompressed_lengths().dtype())?
95 as i32,
96 }))
97 }
98
99 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
100 Ok(Some(metadata.serialize()))
101 }
102
103 fn deserialize(buffer: &[u8]) -> VortexResult<Self::Metadata> {
104 Ok(ProstMetadata(
105 <ProstMetadata<FSSTMetadata> as DeserializeMetadata>::deserialize(buffer)?,
106 ))
107 }
108
109 fn build(
110 &self,
111 dtype: &DType,
112 len: usize,
113 metadata: &Self::Metadata,
114 buffers: &[BufferHandle],
115 children: &dyn ArrayChildren,
116 ) -> VortexResult<FSSTArray> {
117 if buffers.len() != 2 {
118 vortex_bail!(InvalidArgument: "Expected 2 buffers, got {}", buffers.len());
119 }
120 let symbols = Buffer::<Symbol>::from_byte_buffer(buffers[0].clone().try_to_bytes()?);
121 let symbol_lengths = Buffer::<u8>::from_byte_buffer(buffers[1].clone().try_to_bytes()?);
122
123 if children.len() != 2 {
124 vortex_bail!(InvalidArgument: "Expected 2 children, got {}", children.len());
125 }
126 let codes = children.get(0, &DType::Binary(dtype.nullability()), len)?;
127 let codes = codes
128 .as_opt::<VarBinVTable>()
129 .ok_or_else(|| {
130 vortex_err!(
131 "Expected VarBinArray for codes, got {}",
132 codes.encoding_id()
133 )
134 })?
135 .clone();
136 let uncompressed_lengths = children.get(
137 1,
138 &DType::Primitive(
139 metadata.0.get_uncompressed_lengths_ptype()?,
140 Nullability::NonNullable,
141 ),
142 len,
143 )?;
144
145 FSSTArray::try_new(
146 dtype.clone(),
147 symbols,
148 symbol_lengths,
149 codes,
150 uncompressed_lengths,
151 )
152 }
153
154 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
155 vortex_ensure!(
156 children.len() == 2,
157 "FSSTArray expects 2 children, got {}",
158 children.len()
159 );
160
161 let mut children_iter = children.into_iter();
162 let codes = children_iter
163 .next()
164 .ok_or_else(|| vortex_err!("FSSTArray with_children missing codes"))?;
165
166 let codes = codes
167 .as_opt::<VarBinVTable>()
168 .ok_or_else(|| {
169 vortex_err!(
170 "Expected VarBinArray for codes, got {}",
171 codes.encoding_id()
172 )
173 })?
174 .clone();
175 let uncompressed_lengths = children_iter
176 .next()
177 .ok_or_else(|| vortex_err!("FSSTArray with_children missing uncompressed_lengths"))?;
178
179 array.codes = codes;
180 array.uncompressed_lengths = uncompressed_lengths;
181
182 Ok(())
183 }
184
185 fn execute_parent(
186 array: &Self::Array,
187 parent: &ArrayRef,
188 child_idx: usize,
189 ctx: &mut ExecutionCtx,
190 ) -> VortexResult<Option<Vector>> {
191 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
192 }
193}
194
/// An FSST-compressed array.
///
/// Each value is stored as a binary string of symbol codes (`codes`), which is decoded against
/// a symbol table shared by the whole array (`symbols` / `symbol_lengths`).
#[derive(Clone)]
pub struct FSSTArray {
    /// Logical dtype of the array.
    dtype: DType,
    /// The symbol table; `try_new` limits it to at most 255 entries.
    symbols: Buffer<Symbol>,
    /// Length of each entry in `symbols`, index-aligned with it.
    symbol_lengths: Buffer<u8>,
    /// Compressed codes, one `DType::Binary` value per array element.
    codes: VarBinArray,
    /// `codes` cached as an `ArrayRef`; this is what the validity vtable reads.
    codes_array: ArrayRef,
    /// Non-nullable integer array with one entry per element.
    /// NOTE(review): presumably the byte length of each decoded value — confirm with the
    /// decompression kernels.
    uncompressed_lengths: ArrayRef,
    /// Cached array statistics.
    stats_set: ArrayStats,

    /// Compressor rebuilt lazily from `symbols` / `symbol_lengths` on first use. It sits behind
    /// an `Arc`, so clones of the array share it.
    compressor: Arc<LazyLock<Compressor, Box<dyn Fn() -> Compressor + Send>>>,
}
210
211impl Debug for FSSTArray {
212 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
213 f.debug_struct("FSSTArray")
214 .field("dtype", &self.dtype)
215 .field("symbols", &self.symbols)
216 .field("symbol_lengths", &self.symbol_lengths)
217 .field("codes", &self.codes)
218 .field("uncompressed_lengths", &self.uncompressed_lengths)
219 .finish()
220 }
221}
222
/// The vtable (encoding) type for [`FSSTArray`], registered under the id `"vortex.fsst"`.
#[derive(Debug)]
pub struct FSSTVTable;
225
226impl FSSTArray {
227 pub fn try_new(
236 dtype: DType,
237 symbols: Buffer<Symbol>,
238 symbol_lengths: Buffer<u8>,
239 codes: VarBinArray,
240 uncompressed_lengths: ArrayRef,
241 ) -> VortexResult<Self> {
242 if symbols.len() > 255 {
244 vortex_bail!(InvalidArgument: "symbols array must have length <= 255");
245 }
246 if symbols.len() != symbol_lengths.len() {
247 vortex_bail!(InvalidArgument: "symbols and symbol_lengths arrays must have same length");
248 }
249
250 if uncompressed_lengths.len() != codes.len() {
251 vortex_bail!(InvalidArgument: "uncompressed_lengths must be same len as codes");
252 }
253
254 if !uncompressed_lengths.dtype().is_int() || uncompressed_lengths.dtype().is_nullable() {
255 vortex_bail!(InvalidArgument: "uncompressed_lengths must have integer type and cannot be nullable, found {}", uncompressed_lengths.dtype());
256 }
257
258 if !matches!(codes.dtype(), DType::Binary(_)) {
260 vortex_bail!(InvalidArgument: "codes array must be DType::Binary type");
261 }
262
263 unsafe {
265 Ok(Self::new_unchecked(
266 dtype,
267 symbols,
268 symbol_lengths,
269 codes,
270 uncompressed_lengths,
271 ))
272 }
273 }
274
275 pub(crate) unsafe fn new_unchecked(
276 dtype: DType,
277 symbols: Buffer<Symbol>,
278 symbol_lengths: Buffer<u8>,
279 codes: VarBinArray,
280 uncompressed_lengths: ArrayRef,
281 ) -> Self {
282 let symbols2 = symbols.clone();
283 let symbol_lengths2 = symbol_lengths.clone();
284 let compressor = Arc::new(LazyLock::new(Box::new(move || {
285 Compressor::rebuild_from(symbols2.as_slice(), symbol_lengths2.as_slice())
286 })
287 as Box<dyn Fn() -> Compressor + Send>));
288 let codes_array = codes.to_array();
289
290 Self {
291 dtype,
292 symbols,
293 symbol_lengths,
294 codes,
295 codes_array,
296 uncompressed_lengths,
297 stats_set: Default::default(),
298 compressor,
299 }
300 }
301
302 pub fn symbols(&self) -> &Buffer<Symbol> {
304 &self.symbols
305 }
306
307 pub fn symbol_lengths(&self) -> &Buffer<u8> {
309 &self.symbol_lengths
310 }
311
312 pub fn codes(&self) -> &VarBinArray {
314 &self.codes
315 }
316
317 #[inline]
319 pub fn codes_dtype(&self) -> &DType {
320 self.codes.dtype()
321 }
322
323 pub fn uncompressed_lengths(&self) -> &ArrayRef {
325 &self.uncompressed_lengths
326 }
327
328 #[inline]
330 pub fn uncompressed_lengths_dtype(&self) -> &DType {
331 self.uncompressed_lengths.dtype()
332 }
333
334 pub fn decompressor(&self) -> Decompressor<'_> {
337 Decompressor::new(self.symbols().as_slice(), self.symbol_lengths().as_slice())
338 }
339
340 pub fn compressor(&self) -> &Compressor {
342 self.compressor.as_ref()
343 }
344}
345
346impl BaseArrayVTable<FSSTVTable> for FSSTVTable {
347 fn len(array: &FSSTArray) -> usize {
348 array.codes().len()
349 }
350
351 fn dtype(array: &FSSTArray) -> &DType {
352 &array.dtype
353 }
354
355 fn stats(array: &FSSTArray) -> StatsSetRef<'_> {
356 array.stats_set.to_ref(array.as_ref())
357 }
358
359 fn array_hash<H: std::hash::Hasher>(array: &FSSTArray, state: &mut H, precision: Precision) {
360 array.dtype.hash(state);
361 array.symbols.array_hash(state, precision);
362 array.symbol_lengths.array_hash(state, precision);
363 array.codes.as_ref().array_hash(state, precision);
364 array.uncompressed_lengths.array_hash(state, precision);
365 }
366
367 fn array_eq(array: &FSSTArray, other: &FSSTArray, precision: Precision) -> bool {
368 array.dtype == other.dtype
369 && array.symbols.array_eq(&other.symbols, precision)
370 && array
371 .symbol_lengths
372 .array_eq(&other.symbol_lengths, precision)
373 && array
374 .codes
375 .as_ref()
376 .array_eq(other.codes.as_ref(), precision)
377 && array
378 .uncompressed_lengths
379 .array_eq(&other.uncompressed_lengths, precision)
380 }
381}
382
383impl ValidityChild<FSSTVTable> for FSSTVTable {
384 fn validity_child(array: &FSSTArray) -> &ArrayRef {
385 &array.codes_array
386 }
387}
388
389impl EncodeVTable<FSSTVTable> for FSSTVTable {
390 fn encode(
391 _vtable: &FSSTVTable,
392 canonical: &Canonical,
393 like: Option<&FSSTArray>,
394 ) -> VortexResult<Option<FSSTArray>> {
395 let array = canonical.clone().into_varbinview();
396
397 let compressor = match like {
398 Some(like) => Compressor::rebuild_from(like.symbols(), like.symbol_lengths()),
399 None => fsst_train_compressor(&array),
400 };
401
402 Ok(Some(fsst_compress(array, &compressor)))
403 }
404}
405
406impl VisitorVTable<FSSTVTable> for FSSTVTable {
407 fn visit_buffers(array: &FSSTArray, visitor: &mut dyn ArrayBufferVisitor) {
408 visitor.visit_buffer(&array.symbols().clone().into_byte_buffer());
409 visitor.visit_buffer(&array.symbol_lengths().clone().into_byte_buffer());
410 }
411
412 fn visit_children(array: &FSSTArray, visitor: &mut dyn ArrayChildVisitor) {
413 visitor.visit_child("codes", &array.codes().to_array());
414 visitor.visit_child("uncompressed_lengths", array.uncompressed_lengths());
415 }
416}
417
#[cfg(test)]
mod test {
    use vortex_array::ProstMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::array::FSSTMetadata;

    /// Runs `check_metadata` on an `FSSTMetadata` with a `U64` uncompressed-lengths type,
    /// under the `"fsst.metadata"` name.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_fsst_metadata() {
        let metadata = FSSTMetadata {
            uncompressed_lengths_ptype: PType::U64 as i32,
        };
        check_metadata("fsst.metadata", ProstMetadata(metadata));
    }
}