vortex_fsst/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8use std::sync::LazyLock;
9
10use fsst::Compressor;
11use fsst::Decompressor;
12use fsst::Symbol;
13use vortex_array::Array;
14use vortex_array::ArrayBufferVisitor;
15use vortex_array::ArrayChildVisitor;
16use vortex_array::ArrayEq;
17use vortex_array::ArrayHash;
18use vortex_array::ArrayRef;
19use vortex_array::Canonical;
20use vortex_array::DeserializeMetadata;
21use vortex_array::ExecutionCtx;
22use vortex_array::Precision;
23use vortex_array::ProstMetadata;
24use vortex_array::SerializeMetadata;
25use vortex_array::arrays::VarBinArray;
26use vortex_array::arrays::VarBinVTable;
27use vortex_array::buffer::BufferHandle;
28use vortex_array::serde::ArrayChildren;
29use vortex_array::stats::ArrayStats;
30use vortex_array::stats::StatsSetRef;
31use vortex_array::vtable;
32use vortex_array::vtable::ArrayId;
33use vortex_array::vtable::ArrayVTable;
34use vortex_array::vtable::ArrayVTableExt;
35use vortex_array::vtable::BaseArrayVTable;
36use vortex_array::vtable::EncodeVTable;
37use vortex_array::vtable::NotSupported;
38use vortex_array::vtable::VTable;
39use vortex_array::vtable::ValidityChild;
40use vortex_array::vtable::ValidityVTableFromChild;
41use vortex_array::vtable::VisitorVTable;
42use vortex_buffer::Buffer;
43use vortex_dtype::DType;
44use vortex_dtype::Nullability;
45use vortex_dtype::PType;
46use vortex_error::VortexResult;
47use vortex_error::vortex_bail;
48use vortex_error::vortex_ensure;
49use vortex_error::vortex_err;
50use vortex_vector::Vector;
51
52use crate::fsst_compress;
53use crate::fsst_train_compressor;
54use crate::kernel::PARENT_KERNELS;
55
// Invokes the `vtable!` macro from `vortex_array` for the FSST encoding.
// NOTE(review): presumably generates the glue that ties `FSSTVTable` to `FSSTArray` —
// confirm against the macro definition.
vtable!(FSST);
57
/// Serialized metadata for an FSST-encoded array.
///
/// Records the primitive type of the `uncompressed_lengths` child so that the child can be
/// requested with the matching dtype when the array is rebuilt.
#[derive(Clone, prost::Message)]
pub struct FSSTMetadata {
    /// [`PType`] of the `uncompressed_lengths` child, stored as its `i32` enumeration value.
    #[prost(enumeration = "PType", tag = "1")]
    uncompressed_lengths_ptype: i32,
}
63
64impl FSSTMetadata {
65    pub fn get_uncompressed_lengths_ptype(&self) -> VortexResult<PType> {
66        PType::try_from(self.uncompressed_lengths_ptype)
67            .map_err(|_| vortex_err!("Invalid PType {}", self.uncompressed_lengths_ptype))
68    }
69}
70
71impl VTable for FSSTVTable {
72    type Array = FSSTArray;
73
74    type Metadata = ProstMetadata<FSSTMetadata>;
75
76    type ArrayVTable = Self;
77    type CanonicalVTable = Self;
78    type OperationsVTable = Self;
79    type ValidityVTable = ValidityVTableFromChild;
80    type VisitorVTable = Self;
81    type ComputeVTable = NotSupported;
82    type EncodeVTable = Self;
83
84    fn id(&self) -> ArrayId {
85        ArrayId::new_ref("vortex.fsst")
86    }
87
88    fn encoding(_array: &Self::Array) -> ArrayVTable {
89        FSSTVTable.as_vtable()
90    }
91
92    fn metadata(array: &FSSTArray) -> VortexResult<Self::Metadata> {
93        Ok(ProstMetadata(FSSTMetadata {
94            uncompressed_lengths_ptype: PType::try_from(array.uncompressed_lengths().dtype())?
95                as i32,
96        }))
97    }
98
99    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
100        Ok(Some(metadata.serialize()))
101    }
102
103    fn deserialize(buffer: &[u8]) -> VortexResult<Self::Metadata> {
104        Ok(ProstMetadata(
105            <ProstMetadata<FSSTMetadata> as DeserializeMetadata>::deserialize(buffer)?,
106        ))
107    }
108
109    fn build(
110        &self,
111        dtype: &DType,
112        len: usize,
113        metadata: &Self::Metadata,
114        buffers: &[BufferHandle],
115        children: &dyn ArrayChildren,
116    ) -> VortexResult<FSSTArray> {
117        if buffers.len() != 2 {
118            vortex_bail!(InvalidArgument: "Expected 2 buffers, got {}", buffers.len());
119        }
120        let symbols = Buffer::<Symbol>::from_byte_buffer(buffers[0].clone().try_to_bytes()?);
121        let symbol_lengths = Buffer::<u8>::from_byte_buffer(buffers[1].clone().try_to_bytes()?);
122
123        if children.len() != 2 {
124            vortex_bail!(InvalidArgument: "Expected 2 children, got {}", children.len());
125        }
126        let codes = children.get(0, &DType::Binary(dtype.nullability()), len)?;
127        let codes = codes
128            .as_opt::<VarBinVTable>()
129            .ok_or_else(|| {
130                vortex_err!(
131                    "Expected VarBinArray for codes, got {}",
132                    codes.encoding_id()
133                )
134            })?
135            .clone();
136        let uncompressed_lengths = children.get(
137            1,
138            &DType::Primitive(
139                metadata.0.get_uncompressed_lengths_ptype()?,
140                Nullability::NonNullable,
141            ),
142            len,
143        )?;
144
145        FSSTArray::try_new(
146            dtype.clone(),
147            symbols,
148            symbol_lengths,
149            codes,
150            uncompressed_lengths,
151        )
152    }
153
154    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
155        vortex_ensure!(
156            children.len() == 2,
157            "FSSTArray expects 2 children, got {}",
158            children.len()
159        );
160
161        let mut children_iter = children.into_iter();
162        let codes = children_iter
163            .next()
164            .ok_or_else(|| vortex_err!("FSSTArray with_children missing codes"))?;
165
166        let codes = codes
167            .as_opt::<VarBinVTable>()
168            .ok_or_else(|| {
169                vortex_err!(
170                    "Expected VarBinArray for codes, got {}",
171                    codes.encoding_id()
172                )
173            })?
174            .clone();
175        let uncompressed_lengths = children_iter
176            .next()
177            .ok_or_else(|| vortex_err!("FSSTArray with_children missing uncompressed_lengths"))?;
178
179        array.codes = codes;
180        array.uncompressed_lengths = uncompressed_lengths;
181
182        Ok(())
183    }
184
185    fn execute_parent(
186        array: &Self::Array,
187        parent: &ArrayRef,
188        child_idx: usize,
189        ctx: &mut ExecutionCtx,
190    ) -> VortexResult<Option<Vector>> {
191        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
192    }
193}
194
/// An FSST-encoded array: a symbol table together with a child array of per-value code
/// sequences and a child array holding the original length of every value.
#[derive(Clone)]
pub struct FSSTArray {
    /// The dtype reported for this array.
    dtype: DType,
    /// The symbol table; at most 255 entries (checked in `try_new`).
    symbols: Buffer<Symbol>,
    /// Per-symbol lengths, parallel to `symbols` (same length, checked in `try_new`).
    symbol_lengths: Buffer<u8>,
    /// The compressed values: a Binary array where each datum is a sequence of 8-bit codes.
    codes: VarBinArray,
    /// NOTE(ngates): this === codes, but is stored as an ArrayRef so we can return &ArrayRef!
    ///
    /// Returned by the validity child accessor, so it has to stay in sync with `codes`.
    codes_array: ArrayRef,
    /// Lengths of the original values before compression, can be compressed.
    ///
    /// Must be a non-nullable integer array with the same length as `codes`
    /// (checked in `try_new`).
    uncompressed_lengths: ArrayRef,
    /// Statistics holder for this array; starts out as the default value.
    stats_set: ArrayStats,

    /// Memoized compressor used for push-down of compute by compressing the RHS.
    ///
    /// Rebuilt lazily from `symbols` and `symbol_lengths` on first access; because it sits
    /// behind an `Arc`, clones of the array share the same lazily-built instance.
    compressor: Arc<LazyLock<Compressor, Box<dyn Fn() -> Compressor + Send>>>,
}
210
211impl Debug for FSSTArray {
212    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
213        f.debug_struct("FSSTArray")
214            .field("dtype", &self.dtype)
215            .field("symbols", &self.symbols)
216            .field("symbol_lengths", &self.symbol_lengths)
217            .field("codes", &self.codes)
218            .field("uncompressed_lengths", &self.uncompressed_lengths)
219            .finish()
220    }
221}
222
/// Vtable type for the FSST encoding. Its trait implementations in this file define how an
/// [`FSSTArray`] is built, serialized, hashed, compared, encoded and visited.
#[derive(Debug)]
pub struct FSSTVTable;
225
226impl FSSTArray {
227    /// Build an FSST array from a set of `symbols` and `codes`.
228    ///
229    /// Symbols are 8-bytes and can represent short strings, each of which is assigned
230    /// a code.
231    ///
232    /// The `codes` array is a Binary array where each binary datum is a sequence of 8-bit codes.
233    /// Each code corresponds either to a symbol, or to the "escape code",
234    /// which tells the decoder to emit the following byte without doing a table lookup.
235    pub fn try_new(
236        dtype: DType,
237        symbols: Buffer<Symbol>,
238        symbol_lengths: Buffer<u8>,
239        codes: VarBinArray,
240        uncompressed_lengths: ArrayRef,
241    ) -> VortexResult<Self> {
242        // Check: symbols must not have length > MAX_CODE
243        if symbols.len() > 255 {
244            vortex_bail!(InvalidArgument: "symbols array must have length <= 255");
245        }
246        if symbols.len() != symbol_lengths.len() {
247            vortex_bail!(InvalidArgument: "symbols and symbol_lengths arrays must have same length");
248        }
249
250        if uncompressed_lengths.len() != codes.len() {
251            vortex_bail!(InvalidArgument: "uncompressed_lengths must be same len as codes");
252        }
253
254        if !uncompressed_lengths.dtype().is_int() || uncompressed_lengths.dtype().is_nullable() {
255            vortex_bail!(InvalidArgument: "uncompressed_lengths must have integer type and cannot be nullable, found {}", uncompressed_lengths.dtype());
256        }
257
258        // Check: strings must be a Binary array.
259        if !matches!(codes.dtype(), DType::Binary(_)) {
260            vortex_bail!(InvalidArgument: "codes array must be DType::Binary type");
261        }
262
263        // SAFETY: all components validated above
264        unsafe {
265            Ok(Self::new_unchecked(
266                dtype,
267                symbols,
268                symbol_lengths,
269                codes,
270                uncompressed_lengths,
271            ))
272        }
273    }
274
275    pub(crate) unsafe fn new_unchecked(
276        dtype: DType,
277        symbols: Buffer<Symbol>,
278        symbol_lengths: Buffer<u8>,
279        codes: VarBinArray,
280        uncompressed_lengths: ArrayRef,
281    ) -> Self {
282        let symbols2 = symbols.clone();
283        let symbol_lengths2 = symbol_lengths.clone();
284        let compressor = Arc::new(LazyLock::new(Box::new(move || {
285            Compressor::rebuild_from(symbols2.as_slice(), symbol_lengths2.as_slice())
286        })
287            as Box<dyn Fn() -> Compressor + Send>));
288        let codes_array = codes.to_array();
289
290        Self {
291            dtype,
292            symbols,
293            symbol_lengths,
294            codes,
295            codes_array,
296            uncompressed_lengths,
297            stats_set: Default::default(),
298            compressor,
299        }
300    }
301
302    /// Access the symbol table array
303    pub fn symbols(&self) -> &Buffer<Symbol> {
304        &self.symbols
305    }
306
307    /// Access the symbol table array
308    pub fn symbol_lengths(&self) -> &Buffer<u8> {
309        &self.symbol_lengths
310    }
311
312    /// Access the codes array
313    pub fn codes(&self) -> &VarBinArray {
314        &self.codes
315    }
316
317    /// Get the DType of the codes array
318    #[inline]
319    pub fn codes_dtype(&self) -> &DType {
320        self.codes.dtype()
321    }
322
323    /// Get the uncompressed length for each element in the array.
324    pub fn uncompressed_lengths(&self) -> &ArrayRef {
325        &self.uncompressed_lengths
326    }
327
328    /// Get the DType of the uncompressed lengths array
329    #[inline]
330    pub fn uncompressed_lengths_dtype(&self) -> &DType {
331        self.uncompressed_lengths.dtype()
332    }
333
334    /// Build a [`Decompressor`][fsst::Decompressor] that can be used to decompress values from
335    /// this array.
336    pub fn decompressor(&self) -> Decompressor<'_> {
337        Decompressor::new(self.symbols().as_slice(), self.symbol_lengths().as_slice())
338    }
339
340    /// Retrieves the FSST compressor.
341    pub fn compressor(&self) -> &Compressor {
342        self.compressor.as_ref()
343    }
344}
345
346impl BaseArrayVTable<FSSTVTable> for FSSTVTable {
347    fn len(array: &FSSTArray) -> usize {
348        array.codes().len()
349    }
350
351    fn dtype(array: &FSSTArray) -> &DType {
352        &array.dtype
353    }
354
355    fn stats(array: &FSSTArray) -> StatsSetRef<'_> {
356        array.stats_set.to_ref(array.as_ref())
357    }
358
359    fn array_hash<H: std::hash::Hasher>(array: &FSSTArray, state: &mut H, precision: Precision) {
360        array.dtype.hash(state);
361        array.symbols.array_hash(state, precision);
362        array.symbol_lengths.array_hash(state, precision);
363        array.codes.as_ref().array_hash(state, precision);
364        array.uncompressed_lengths.array_hash(state, precision);
365    }
366
367    fn array_eq(array: &FSSTArray, other: &FSSTArray, precision: Precision) -> bool {
368        array.dtype == other.dtype
369            && array.symbols.array_eq(&other.symbols, precision)
370            && array
371                .symbol_lengths
372                .array_eq(&other.symbol_lengths, precision)
373            && array
374                .codes
375                .as_ref()
376                .array_eq(other.codes.as_ref(), precision)
377            && array
378                .uncompressed_lengths
379                .array_eq(&other.uncompressed_lengths, precision)
380    }
381}
382
383impl ValidityChild<FSSTVTable> for FSSTVTable {
384    fn validity_child(array: &FSSTArray) -> &ArrayRef {
385        &array.codes_array
386    }
387}
388
389impl EncodeVTable<FSSTVTable> for FSSTVTable {
390    fn encode(
391        _vtable: &FSSTVTable,
392        canonical: &Canonical,
393        like: Option<&FSSTArray>,
394    ) -> VortexResult<Option<FSSTArray>> {
395        let array = canonical.clone().into_varbinview();
396
397        let compressor = match like {
398            Some(like) => Compressor::rebuild_from(like.symbols(), like.symbol_lengths()),
399            None => fsst_train_compressor(&array),
400        };
401
402        Ok(Some(fsst_compress(array, &compressor)))
403    }
404}
405
406impl VisitorVTable<FSSTVTable> for FSSTVTable {
407    fn visit_buffers(array: &FSSTArray, visitor: &mut dyn ArrayBufferVisitor) {
408        visitor.visit_buffer(&array.symbols().clone().into_byte_buffer());
409        visitor.visit_buffer(&array.symbol_lengths().clone().into_byte_buffer());
410    }
411
412    fn visit_children(array: &FSSTArray, visitor: &mut dyn ArrayChildVisitor) {
413        visitor.visit_child("codes", &array.codes().to_array());
414        visitor.visit_child("uncompressed_lengths", array.uncompressed_lengths());
415    }
416}
417
#[cfg(test)]
mod test {
    use vortex_array::ProstMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::array::FSSTMetadata;

    /// Runs `check_metadata` on an [`FSSTMetadata`] with a `U64` lengths ptype, under the
    /// name `fsst.metadata`. Ignored under miri.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_fsst_metadata() {
        let metadata = ProstMetadata(FSSTMetadata {
            uncompressed_lengths_ptype: PType::U64 as i32,
        });

        check_metadata("fsst.metadata", metadata);
    }
}