Skip to main content

vortex_fsst/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8use std::sync::LazyLock;
9
10use fsst::Compressor;
11use fsst::Decompressor;
12use fsst::Symbol;
13use vortex_array::Array;
14use vortex_array::ArrayBufferVisitor;
15use vortex_array::ArrayChildVisitor;
16use vortex_array::ArrayEq;
17use vortex_array::ArrayHash;
18use vortex_array::ArrayRef;
19use vortex_array::Canonical;
20use vortex_array::DeserializeMetadata;
21use vortex_array::ExecutionCtx;
22use vortex_array::IntoArray;
23use vortex_array::Precision;
24use vortex_array::ProstMetadata;
25use vortex_array::SerializeMetadata;
26use vortex_array::arrays::VarBinArray;
27use vortex_array::arrays::VarBinVTable;
28use vortex_array::buffer::BufferHandle;
29use vortex_array::builders::ArrayBuilder;
30use vortex_array::builders::VarBinViewBuilder;
31use vortex_array::dtype::DType;
32use vortex_array::dtype::Nullability;
33use vortex_array::dtype::PType;
34use vortex_array::serde::ArrayChildren;
35use vortex_array::stats::ArrayStats;
36use vortex_array::stats::StatsSetRef;
37use vortex_array::validity::Validity;
38use vortex_array::vtable;
39use vortex_array::vtable::ArrayId;
40use vortex_array::vtable::BaseArrayVTable;
41use vortex_array::vtable::VTable;
42use vortex_array::vtable::ValidityChild;
43use vortex_array::vtable::ValidityHelper;
44use vortex_array::vtable::ValidityVTableFromChild;
45use vortex_array::vtable::VisitorVTable;
46use vortex_array::vtable::validity_nchildren;
47use vortex_buffer::Buffer;
48use vortex_buffer::ByteBuffer;
49use vortex_error::VortexResult;
50use vortex_error::vortex_bail;
51use vortex_error::vortex_ensure;
52use vortex_error::vortex_err;
53use vortex_session::VortexSession;
54
55use crate::canonical::canonicalize_fsst;
56use crate::canonical::fsst_decode_views;
57use crate::kernel::PARENT_KERNELS;
58use crate::rules::RULES;
59
// Crate-wide `vtable!` macro invocation for the FSST encoding.
// NOTE(review): presumably this generates the shared glue tying `FSSTArray` to `FSSTVTable`
// (the `FSSTVTable` struct itself is declared by hand in this file) — confirm against the
// `vtable!` macro definition.
vtable!(FSST);
61
62#[derive(Clone, prost::Message)]
63pub struct FSSTMetadata {
64    #[prost(enumeration = "PType", tag = "1")]
65    uncompressed_lengths_ptype: i32,
66
67    #[prost(enumeration = "PType", tag = "2")]
68    codes_offsets_ptype: i32,
69}
70
71impl FSSTMetadata {
72    pub fn get_uncompressed_lengths_ptype(&self) -> VortexResult<PType> {
73        PType::try_from(self.uncompressed_lengths_ptype)
74            .map_err(|_| vortex_err!("Invalid PType {}", self.uncompressed_lengths_ptype))
75    }
76}
77
78impl VTable for FSSTVTable {
79    type Array = FSSTArray;
80
81    type Metadata = ProstMetadata<FSSTMetadata>;
82
83    type ArrayVTable = Self;
84    type OperationsVTable = Self;
85    type ValidityVTable = ValidityVTableFromChild;
86    type VisitorVTable = Self;
87
88    fn id(_array: &Self::Array) -> ArrayId {
89        Self::ID
90    }
91
92    fn metadata(array: &FSSTArray) -> VortexResult<Self::Metadata> {
93        Ok(ProstMetadata(FSSTMetadata {
94            uncompressed_lengths_ptype: array.uncompressed_lengths().dtype().as_ptype().into(),
95            codes_offsets_ptype: array.codes.offsets().dtype().as_ptype().into(),
96        }))
97    }
98
99    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
100        Ok(Some(metadata.serialize()))
101    }
102
103    fn deserialize(
104        bytes: &[u8],
105        _dtype: &DType,
106        _len: usize,
107        _buffers: &[BufferHandle],
108        _session: &VortexSession,
109    ) -> VortexResult<Self::Metadata> {
110        Ok(ProstMetadata(
111            <ProstMetadata<FSSTMetadata> as DeserializeMetadata>::deserialize(bytes)?,
112        ))
113    }
114
115    fn append_to_builder(
116        array: &FSSTArray,
117        builder: &mut dyn ArrayBuilder,
118        ctx: &mut ExecutionCtx,
119    ) -> VortexResult<()> {
120        let Some(builder) = builder.as_any_mut().downcast_mut::<VarBinViewBuilder>() else {
121            builder.extend_from_array(&array.to_array().execute::<Canonical>(ctx)?.into_array());
122            return Ok(());
123        };
124
125        // Decompress the whole block of data into a new buffer, and create some views
126        // from it instead.
127        let (buffers, views) = fsst_decode_views(array, builder.completed_block_count(), ctx)?;
128
129        builder.push_buffer_and_adjusted_views(&buffers, &views, array.validity_mask()?);
130        Ok(())
131    }
132
133    fn build(
134        dtype: &DType,
135        len: usize,
136        metadata: &Self::Metadata,
137        buffers: &[BufferHandle],
138        children: &dyn ArrayChildren,
139    ) -> VortexResult<FSSTArray> {
140        let symbols = Buffer::<Symbol>::from_byte_buffer(buffers[0].clone().try_to_host_sync()?);
141        let symbol_lengths = Buffer::<u8>::from_byte_buffer(buffers[1].clone().try_to_host_sync()?);
142
143        // Check for the legacy deserialization path.
144        if buffers.len() == 2 {
145            if children.len() != 2 {
146                vortex_bail!(InvalidArgument: "Expected 2 children, got {}", children.len());
147            }
148            let codes = children.get(0, &DType::Binary(dtype.nullability()), len)?;
149            let codes = codes
150                .as_opt::<VarBinVTable>()
151                .ok_or_else(|| {
152                    vortex_err!(
153                        "Expected VarBinArray for codes, got {}",
154                        codes.encoding_id()
155                    )
156                })?
157                .clone();
158            let uncompressed_lengths = children.get(
159                1,
160                &DType::Primitive(
161                    metadata.0.get_uncompressed_lengths_ptype()?,
162                    Nullability::NonNullable,
163                ),
164                len,
165            )?;
166
167            return FSSTArray::try_new(
168                dtype.clone(),
169                symbols,
170                symbol_lengths,
171                codes,
172                uncompressed_lengths,
173            );
174        }
175
176        // Check for the current deserialization path.
177        if buffers.len() == 3 {
178            let uncompressed_lengths = children.get(
179                0,
180                &DType::Primitive(
181                    metadata.0.get_uncompressed_lengths_ptype()?,
182                    Nullability::NonNullable,
183                ),
184                len,
185            )?;
186
187            let codes_buffer = ByteBuffer::from_byte_buffer(buffers[2].clone().try_to_host_sync()?);
188            let codes_offsets = children.get(
189                1,
190                &DType::Primitive(
191                    PType::try_from(metadata.codes_offsets_ptype)?,
192                    Nullability::NonNullable,
193                ),
194                // VarBin offsets are len + 1
195                len + 1,
196            )?;
197
198            let codes_validity = if children.len() == 2 {
199                Validity::from(dtype.nullability())
200            } else if children.len() == 3 {
201                let validity = children.get(2, &Validity::DTYPE, len)?;
202                Validity::Array(validity)
203            } else {
204                vortex_bail!("Expected 0 or 1 child, got {}", children.len());
205            };
206
207            let codes = VarBinArray::try_new(
208                codes_offsets,
209                codes_buffer,
210                DType::Binary(dtype.nullability()),
211                codes_validity,
212            )?;
213
214            return FSSTArray::try_new(
215                dtype.clone(),
216                symbols,
217                symbol_lengths,
218                codes,
219                uncompressed_lengths,
220            );
221        }
222
223        vortex_bail!(
224            "InvalidArgument: Expected 2 or 3 buffers, got {}",
225            buffers.len()
226        );
227    }
228
229    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
230        vortex_ensure!(
231            children.len() == 2,
232            "FSSTArray expects 2 children, got {}",
233            children.len()
234        );
235
236        let mut children_iter = children.into_iter();
237        let codes = children_iter
238            .next()
239            .ok_or_else(|| vortex_err!("FSSTArray with_children missing codes"))?;
240
241        let codes = codes
242            .as_opt::<VarBinVTable>()
243            .ok_or_else(|| {
244                vortex_err!(
245                    "Expected VarBinArray for codes, got {}",
246                    codes.encoding_id()
247                )
248            })?
249            .clone();
250        let uncompressed_lengths = children_iter
251            .next()
252            .ok_or_else(|| vortex_err!("FSSTArray with_children missing uncompressed_lengths"))?;
253
254        array.codes = codes;
255        array.uncompressed_lengths = uncompressed_lengths;
256
257        Ok(())
258    }
259
260    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
261        canonicalize_fsst(array, ctx)
262    }
263
264    fn execute_parent(
265        array: &Self::Array,
266        parent: &ArrayRef,
267        child_idx: usize,
268        ctx: &mut ExecutionCtx,
269    ) -> VortexResult<Option<ArrayRef>> {
270        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
271    }
272
273    fn reduce_parent(
274        array: &Self::Array,
275        parent: &ArrayRef,
276        child_idx: usize,
277    ) -> VortexResult<Option<ArrayRef>> {
278        RULES.evaluate(array, parent, child_idx)
279    }
280}
281
/// A string array compressed with an FSST symbol table.
///
/// Each value is stored as a sequence of 8-bit codes in `codes`; the symbol table
/// (`symbols` / `symbol_lengths`) maps codes back to bytes on decompression.
#[derive(Clone)]
pub struct FSSTArray {
    /// Logical dtype of the decompressed values.
    dtype: DType,
    /// The symbol table; each `Symbol` holds 8 bytes (`try_new` allows at most 255 symbols).
    symbols: Buffer<Symbol>,
    /// Number of significant bytes in each symbol, index-aligned with `symbols`.
    symbol_lengths: Buffer<u8>,
    /// Compressed codes, one binary value per element; the array's validity is taken from it.
    codes: VarBinArray,
    /// NOTE(ngates): this === codes, but is stored as an ArrayRef so we can return &ArrayRef!
    codes_array: ArrayRef,
    /// Lengths of the original values before compression, can be compressed.
    uncompressed_lengths: ArrayRef,
    /// Statistics for this array; starts as `Default::default()` in `new_unchecked`.
    stats_set: ArrayStats,

    /// Memoized compressor used for push-down of compute by compressing the RHS.
    compressor: Arc<LazyLock<Compressor, Box<dyn Fn() -> Compressor + Send>>>,
}
297
298impl Debug for FSSTArray {
299    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
300        f.debug_struct("FSSTArray")
301            .field("dtype", &self.dtype)
302            .field("symbols", &self.symbols)
303            .field("symbol_lengths", &self.symbol_lengths)
304            .field("codes", &self.codes)
305            .field("uncompressed_lengths", &self.uncompressed_lengths)
306            .finish()
307    }
308}
309
/// The vtable (encoding descriptor) for [`FSSTArray`].
#[derive(Debug)]
pub struct FSSTVTable;

impl FSSTVTable {
    /// Encoding identifier for FSST arrays: `"vortex.fsst"`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.fsst");
}
316
317impl FSSTArray {
318    /// Build an FSST array from a set of `symbols` and `codes`.
319    ///
320    /// Symbols are 8-bytes and can represent short strings, each of which is assigned
321    /// a code.
322    ///
323    /// The `codes` array is a Binary array where each binary datum is a sequence of 8-bit codes.
324    /// Each code corresponds either to a symbol, or to the "escape code",
325    /// which tells the decoder to emit the following byte without doing a table lookup.
326    pub fn try_new(
327        dtype: DType,
328        symbols: Buffer<Symbol>,
329        symbol_lengths: Buffer<u8>,
330        codes: VarBinArray,
331        uncompressed_lengths: ArrayRef,
332    ) -> VortexResult<Self> {
333        // Check: symbols must not have length > MAX_CODE
334        if symbols.len() > 255 {
335            vortex_bail!(InvalidArgument: "symbols array must have length <= 255");
336        }
337        if symbols.len() != symbol_lengths.len() {
338            vortex_bail!(InvalidArgument: "symbols and symbol_lengths arrays must have same length");
339        }
340
341        if uncompressed_lengths.len() != codes.len() {
342            vortex_bail!(InvalidArgument: "uncompressed_lengths must be same len as codes");
343        }
344
345        if !uncompressed_lengths.dtype().is_int() || uncompressed_lengths.dtype().is_nullable() {
346            vortex_bail!(InvalidArgument: "uncompressed_lengths must have integer type and cannot be nullable, found {}", uncompressed_lengths.dtype());
347        }
348
349        // Check: strings must be a Binary array.
350        if !matches!(codes.dtype(), DType::Binary(_)) {
351            vortex_bail!(InvalidArgument: "codes array must be DType::Binary type");
352        }
353
354        // SAFETY: all components validated above
355        unsafe {
356            Ok(Self::new_unchecked(
357                dtype,
358                symbols,
359                symbol_lengths,
360                codes,
361                uncompressed_lengths,
362            ))
363        }
364    }
365
366    pub(crate) unsafe fn new_unchecked(
367        dtype: DType,
368        symbols: Buffer<Symbol>,
369        symbol_lengths: Buffer<u8>,
370        codes: VarBinArray,
371        uncompressed_lengths: ArrayRef,
372    ) -> Self {
373        let symbols2 = symbols.clone();
374        let symbol_lengths2 = symbol_lengths.clone();
375        let compressor = Arc::new(LazyLock::new(Box::new(move || {
376            Compressor::rebuild_from(symbols2.as_slice(), symbol_lengths2.as_slice())
377        })
378            as Box<dyn Fn() -> Compressor + Send>));
379        let codes_array = codes.to_array();
380
381        Self {
382            dtype,
383            symbols,
384            symbol_lengths,
385            codes,
386            codes_array,
387            uncompressed_lengths,
388            stats_set: Default::default(),
389            compressor,
390        }
391    }
392
393    /// Access the symbol table array
394    pub fn symbols(&self) -> &Buffer<Symbol> {
395        &self.symbols
396    }
397
398    /// Access the symbol table array
399    pub fn symbol_lengths(&self) -> &Buffer<u8> {
400        &self.symbol_lengths
401    }
402
403    /// Access the codes array
404    pub fn codes(&self) -> &VarBinArray {
405        &self.codes
406    }
407
408    /// Get the DType of the codes array
409    #[inline]
410    pub fn codes_dtype(&self) -> &DType {
411        self.codes.dtype()
412    }
413
414    /// Get the uncompressed length for each element in the array.
415    pub fn uncompressed_lengths(&self) -> &ArrayRef {
416        &self.uncompressed_lengths
417    }
418
419    /// Get the DType of the uncompressed lengths array
420    #[inline]
421    pub fn uncompressed_lengths_dtype(&self) -> &DType {
422        self.uncompressed_lengths.dtype()
423    }
424
425    /// Build a [`Decompressor`][fsst::Decompressor] that can be used to decompress values from
426    /// this array.
427    pub fn decompressor(&self) -> Decompressor<'_> {
428        Decompressor::new(self.symbols().as_slice(), self.symbol_lengths().as_slice())
429    }
430
431    /// Retrieves the FSST compressor.
432    pub fn compressor(&self) -> &Compressor {
433        self.compressor.as_ref()
434    }
435}
436
437impl BaseArrayVTable<FSSTVTable> for FSSTVTable {
438    fn len(array: &FSSTArray) -> usize {
439        array.codes().len()
440    }
441
442    fn dtype(array: &FSSTArray) -> &DType {
443        &array.dtype
444    }
445
446    fn stats(array: &FSSTArray) -> StatsSetRef<'_> {
447        array.stats_set.to_ref(array.as_ref())
448    }
449
450    fn array_hash<H: std::hash::Hasher>(array: &FSSTArray, state: &mut H, precision: Precision) {
451        array.dtype.hash(state);
452        array.symbols.array_hash(state, precision);
453        array.symbol_lengths.array_hash(state, precision);
454        array.codes.as_ref().array_hash(state, precision);
455        array.uncompressed_lengths.array_hash(state, precision);
456    }
457
458    fn array_eq(array: &FSSTArray, other: &FSSTArray, precision: Precision) -> bool {
459        array.dtype == other.dtype
460            && array.symbols.array_eq(&other.symbols, precision)
461            && array
462                .symbol_lengths
463                .array_eq(&other.symbol_lengths, precision)
464            && array
465                .codes
466                .as_ref()
467                .array_eq(other.codes.as_ref(), precision)
468            && array
469                .uncompressed_lengths
470                .array_eq(&other.uncompressed_lengths, precision)
471    }
472}
473
impl ValidityChild<FSSTVTable> for FSSTVTable {
    /// The array's validity is that of its codes. `codes_array` is the `ArrayRef` form of
    /// `codes`, kept on the struct so that a reference can be returned here.
    fn validity_child(array: &FSSTArray) -> &ArrayRef {
        &array.codes_array
    }
}
479
480impl VisitorVTable<FSSTVTable> for FSSTVTable {
481    fn visit_buffers(array: &FSSTArray, visitor: &mut dyn ArrayBufferVisitor) {
482        visitor.visit_buffer_handle(
483            "symbols",
484            &BufferHandle::new_host(array.symbols().clone().into_byte_buffer()),
485        );
486        visitor.visit_buffer_handle(
487            "symbol_lengths",
488            &BufferHandle::new_host(array.symbol_lengths().clone().into_byte_buffer()),
489        );
490        visitor.visit_buffer_handle("compressed_codes", array.codes.bytes_handle())
491    }
492
493    fn nbuffers(_array: &FSSTArray) -> usize {
494        3
495    }
496
497    fn visit_children(array: &FSSTArray, visitor: &mut dyn ArrayChildVisitor) {
498        visitor.visit_child("uncompressed_lengths", array.uncompressed_lengths());
499        visitor.visit_child("codes_offsets", array.codes.offsets());
500        visitor.visit_validity(array.codes.validity(), array.codes.len());
501    }
502
503    fn nchildren(array: &FSSTArray) -> usize {
504        // uncompressed_lengths + codes_offsets + optional validity
505        2 + validity_nchildren(array.codes.validity())
506    }
507}
508
#[cfg(test)]
mod test {
    use fsst::Compressor;
    use fsst::Symbol;
    use vortex_array::Array;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::ProstMetadata;
    use vortex_array::VortexSessionExecute;
    use vortex_array::accessor::ArrayAccessor;
    use vortex_array::arrays::VarBinViewArray;
    use vortex_array::buffer::BufferHandle;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::test_harness::check_metadata;
    use vortex_array::vtable::VTable;
    use vortex_buffer::Buffer;
    use vortex_error::VortexError;

    use crate::FSSTVTable;
    use crate::array::FSSTMetadata;
    use crate::fsst_compress_iter;

    /// Runs `FSSTMetadata` through the shared `check_metadata` harness under the name
    /// `fsst.metadata`.
    /// NOTE(review): presumably this compares against a stored fixture so that changes to the
    /// serialized layout are caught — confirm in `test_harness`.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_fsst_metadata() {
        check_metadata(
            "fsst.metadata",
            ProstMetadata(FSSTMetadata {
                uncompressed_lengths_ptype: PType::U64 as i32,
                codes_offsets_ptype: PType::I32 as i32,
            }),
        );
    }

    /// The original FSST array stored codes as a VarBinArray child and required that the child
    /// have this encoding. Vortex forbids this kind of introspection, therefore we had to fix
    /// the array to store the compressed offsets and compressed data buffer separately, and only
    /// use VarBinArray to delegate behavior.
    ///
    /// This test manually constructs an old-style FSST array and ensures that it can still be
    /// deserialized.
    #[test]
    fn test_back_compat() {
        // A two-entry symbol table: the first symbol has 3 significant bytes ("abc"), the
        // second uses all 8.
        let symbols = Buffer::<Symbol>::copy_from([
            Symbol::from_slice(b"abc00000"),
            Symbol::from_slice(b"defghijk"),
        ]);
        let symbol_lengths = Buffer::<u8>::copy_from([3, 8]);

        // Compress two values with that table; its codes and lengths are then split out to
        // mimic the legacy on-disk layout.
        let compressor = Compressor::rebuild_from(symbols.as_slice(), symbol_lengths.as_slice());
        let fsst_array = fsst_compress_iter(
            [Some(b"abcabcab".as_ref()), Some(b"defghijk".as_ref())].into_iter(),
            2,
            DType::Utf8(Nullability::NonNullable),
            &compressor,
        );

        let compressed_codes = fsst_array.codes().clone();

        // There were two buffers:
        // 1. The 8 byte symbols
        // 2. The symbol lengths as u8.
        let buffers = [
            BufferHandle::new_host(symbols.into_byte_buffer()),
            BufferHandle::new_host(symbol_lengths.into_byte_buffer()),
        ];

        // There were 2 children:
        // 1. The compressed codes, stored as a VarBinArray.
        // 2. The uncompressed lengths, stored as a Primitive array.
        let children = vec![
            compressed_codes.into_array(),
            fsst_array.uncompressed_lengths().clone(),
        ];

        // Passing only 2 buffers makes `build` take its legacy path.
        let fsst = FSSTVTable::build(
            &DType::Utf8(Nullability::NonNullable),
            2,
            &ProstMetadata(FSSTMetadata {
                uncompressed_lengths_ptype: fsst_array
                    .uncompressed_lengths()
                    .dtype()
                    .as_ptype()
                    .into(),
                // Legacy array did not store this field, use Protobuf default of 0.
                codes_offsets_ptype: 0,
            }),
            &buffers,
            &children.as_slice(),
        )
        .unwrap();

        // Decompressing the rebuilt array must yield the original two values.
        let decompressed = fsst
            .into_array()
            .execute::<VarBinViewArray>(&mut LEGACY_SESSION.create_execution_ctx())
            .unwrap();
        decompressed
            .with_iterator(|it| {
                assert_eq!(it.next().unwrap(), Some(b"abcabcab".as_ref()));
                assert_eq!(it.next().unwrap(), Some(b"defghijk".as_ref()));
                Ok::<_, VortexError>(())
            })
            .unwrap()
    }
}