Skip to main content

vortex_fsst/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8use std::sync::LazyLock;
9
10use fsst::Compressor;
11use fsst::Decompressor;
12use fsst::Symbol;
13use vortex_array::Array;
14use vortex_array::ArrayEq;
15use vortex_array::ArrayHash;
16use vortex_array::ArrayRef;
17use vortex_array::Canonical;
18use vortex_array::DeserializeMetadata;
19use vortex_array::ExecutionCtx;
20use vortex_array::IntoArray;
21use vortex_array::Precision;
22use vortex_array::ProstMetadata;
23use vortex_array::SerializeMetadata;
24use vortex_array::arrays::VarBinArray;
25use vortex_array::arrays::VarBinVTable;
26use vortex_array::buffer::BufferHandle;
27use vortex_array::builders::ArrayBuilder;
28use vortex_array::builders::VarBinViewBuilder;
29use vortex_array::dtype::DType;
30use vortex_array::dtype::Nullability;
31use vortex_array::dtype::PType;
32use vortex_array::serde::ArrayChildren;
33use vortex_array::stats::ArrayStats;
34use vortex_array::stats::StatsSetRef;
35use vortex_array::validity::Validity;
36use vortex_array::vtable;
37use vortex_array::vtable::ArrayId;
38use vortex_array::vtable::VTable;
39use vortex_array::vtable::ValidityChild;
40use vortex_array::vtable::ValidityHelper;
41use vortex_array::vtable::ValidityVTableFromChild;
42use vortex_array::vtable::validity_nchildren;
43use vortex_array::vtable::validity_to_child;
44use vortex_buffer::Buffer;
45use vortex_buffer::ByteBuffer;
46use vortex_error::VortexResult;
47use vortex_error::vortex_bail;
48use vortex_error::vortex_ensure;
49use vortex_error::vortex_err;
50use vortex_error::vortex_panic;
51use vortex_session::VortexSession;
52
53use crate::canonical::canonicalize_fsst;
54use crate::canonical::fsst_decode_views;
55use crate::kernel::PARENT_KERNELS;
56use crate::rules::RULES;
57
// NOTE(review): presumably expands to the glue between `FSSTArray` and `FSSTVTable`
// (e.g. the `as_ref()` / `to_array()` conversions used further down) — confirm against the
// `vtable!` macro definition in `vortex_array`.
vtable!(FSST);
59
/// Serialized metadata of an FSST array, encoded as a protobuf message.
///
/// Both fields hold the protobuf enum value of a `PType`.
#[derive(Clone, prost::Message)]
pub struct FSSTMetadata {
    /// `PType` of the `uncompressed_lengths` child.
    #[prost(enumeration = "PType", tag = "1")]
    uncompressed_lengths_ptype: i32,

    /// `PType` of the codes offsets child.
    ///
    /// Arrays written in the legacy layout did not store this field, so it reads back as the
    /// protobuf default of 0 for them.
    #[prost(enumeration = "PType", tag = "2")]
    codes_offsets_ptype: i32,
}
68
69impl FSSTMetadata {
70    pub fn get_uncompressed_lengths_ptype(&self) -> VortexResult<PType> {
71        PType::try_from(self.uncompressed_lengths_ptype)
72            .map_err(|_| vortex_err!("Invalid PType {}", self.uncompressed_lengths_ptype))
73    }
74}
75
76impl VTable for FSSTVTable {
77    type Array = FSSTArray;
78
79    type Metadata = ProstMetadata<FSSTMetadata>;
80    type OperationsVTable = Self;
81    type ValidityVTable = ValidityVTableFromChild;
82
83    fn id(_array: &Self::Array) -> ArrayId {
84        Self::ID
85    }
86
87    fn len(array: &FSSTArray) -> usize {
88        array.codes().len()
89    }
90
91    fn dtype(array: &FSSTArray) -> &DType {
92        &array.dtype
93    }
94
95    fn stats(array: &FSSTArray) -> StatsSetRef<'_> {
96        array.stats_set.to_ref(array.as_ref())
97    }
98
99    fn array_hash<H: std::hash::Hasher>(array: &FSSTArray, state: &mut H, precision: Precision) {
100        array.dtype.hash(state);
101        array.symbols.array_hash(state, precision);
102        array.symbol_lengths.array_hash(state, precision);
103        array.codes.as_ref().array_hash(state, precision);
104        array.uncompressed_lengths.array_hash(state, precision);
105    }
106
107    fn array_eq(array: &FSSTArray, other: &FSSTArray, precision: Precision) -> bool {
108        array.dtype == other.dtype
109            && array.symbols.array_eq(&other.symbols, precision)
110            && array
111                .symbol_lengths
112                .array_eq(&other.symbol_lengths, precision)
113            && array
114                .codes
115                .as_ref()
116                .array_eq(other.codes.as_ref(), precision)
117            && array
118                .uncompressed_lengths
119                .array_eq(&other.uncompressed_lengths, precision)
120    }
121
122    fn nbuffers(_array: &FSSTArray) -> usize {
123        3
124    }
125
126    fn buffer(array: &FSSTArray, idx: usize) -> BufferHandle {
127        match idx {
128            0 => BufferHandle::new_host(array.symbols().clone().into_byte_buffer()),
129            1 => BufferHandle::new_host(array.symbol_lengths().clone().into_byte_buffer()),
130            2 => array.codes.bytes_handle().clone(),
131            _ => vortex_panic!("FSSTArray buffer index {idx} out of bounds"),
132        }
133    }
134
135    fn buffer_name(_array: &FSSTArray, idx: usize) -> Option<String> {
136        match idx {
137            0 => Some("symbols".to_string()),
138            1 => Some("symbol_lengths".to_string()),
139            2 => Some("compressed_codes".to_string()),
140            _ => vortex_panic!("FSSTArray buffer_name index {idx} out of bounds"),
141        }
142    }
143
144    fn nchildren(array: &FSSTArray) -> usize {
145        2 + validity_nchildren(array.codes.validity())
146    }
147
148    fn child(array: &FSSTArray, idx: usize) -> ArrayRef {
149        match idx {
150            0 => array.uncompressed_lengths().clone(),
151            1 => array.codes.offsets().clone(),
152            2 => validity_to_child(array.codes.validity(), array.codes.len())
153                .unwrap_or_else(|| vortex_panic!("FSSTArray child index {idx} out of bounds")),
154            _ => vortex_panic!("FSSTArray child index {idx} out of bounds"),
155        }
156    }
157
158    fn child_name(_array: &FSSTArray, idx: usize) -> String {
159        match idx {
160            0 => "uncompressed_lengths".to_string(),
161            1 => "codes_offsets".to_string(),
162            2 => "validity".to_string(),
163            _ => vortex_panic!("FSSTArray child_name index {idx} out of bounds"),
164        }
165    }
166
167    fn metadata(array: &FSSTArray) -> VortexResult<Self::Metadata> {
168        Ok(ProstMetadata(FSSTMetadata {
169            uncompressed_lengths_ptype: array.uncompressed_lengths().dtype().as_ptype().into(),
170            codes_offsets_ptype: array.codes.offsets().dtype().as_ptype().into(),
171        }))
172    }
173
174    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
175        Ok(Some(metadata.serialize()))
176    }
177
178    fn deserialize(
179        bytes: &[u8],
180        _dtype: &DType,
181        _len: usize,
182        _buffers: &[BufferHandle],
183        _session: &VortexSession,
184    ) -> VortexResult<Self::Metadata> {
185        Ok(ProstMetadata(
186            <ProstMetadata<FSSTMetadata> as DeserializeMetadata>::deserialize(bytes)?,
187        ))
188    }
189
190    fn append_to_builder(
191        array: &FSSTArray,
192        builder: &mut dyn ArrayBuilder,
193        ctx: &mut ExecutionCtx,
194    ) -> VortexResult<()> {
195        let Some(builder) = builder.as_any_mut().downcast_mut::<VarBinViewBuilder>() else {
196            builder.extend_from_array(&array.to_array().execute::<Canonical>(ctx)?.into_array());
197            return Ok(());
198        };
199
200        // Decompress the whole block of data into a new buffer, and create some views
201        // from it instead.
202        let (buffers, views) = fsst_decode_views(array, builder.completed_block_count(), ctx)?;
203
204        builder.push_buffer_and_adjusted_views(&buffers, &views, array.validity_mask()?);
205        Ok(())
206    }
207
208    fn build(
209        dtype: &DType,
210        len: usize,
211        metadata: &Self::Metadata,
212        buffers: &[BufferHandle],
213        children: &dyn ArrayChildren,
214    ) -> VortexResult<FSSTArray> {
215        let symbols = Buffer::<Symbol>::from_byte_buffer(buffers[0].clone().try_to_host_sync()?);
216        let symbol_lengths = Buffer::<u8>::from_byte_buffer(buffers[1].clone().try_to_host_sync()?);
217
218        // Check for the legacy deserialization path.
219        if buffers.len() == 2 {
220            if children.len() != 2 {
221                vortex_bail!(InvalidArgument: "Expected 2 children, got {}", children.len());
222            }
223            let codes = children.get(0, &DType::Binary(dtype.nullability()), len)?;
224            let codes = codes
225                .as_opt::<VarBinVTable>()
226                .ok_or_else(|| {
227                    vortex_err!(
228                        "Expected VarBinArray for codes, got {}",
229                        codes.encoding_id()
230                    )
231                })?
232                .clone();
233            let uncompressed_lengths = children.get(
234                1,
235                &DType::Primitive(
236                    metadata.0.get_uncompressed_lengths_ptype()?,
237                    Nullability::NonNullable,
238                ),
239                len,
240            )?;
241
242            return FSSTArray::try_new(
243                dtype.clone(),
244                symbols,
245                symbol_lengths,
246                codes,
247                uncompressed_lengths,
248            );
249        }
250
251        // Check for the current deserialization path.
252        if buffers.len() == 3 {
253            let uncompressed_lengths = children.get(
254                0,
255                &DType::Primitive(
256                    metadata.0.get_uncompressed_lengths_ptype()?,
257                    Nullability::NonNullable,
258                ),
259                len,
260            )?;
261
262            let codes_buffer = ByteBuffer::from_byte_buffer(buffers[2].clone().try_to_host_sync()?);
263            let codes_offsets = children.get(
264                1,
265                &DType::Primitive(
266                    PType::try_from(metadata.codes_offsets_ptype)?,
267                    Nullability::NonNullable,
268                ),
269                // VarBin offsets are len + 1
270                len + 1,
271            )?;
272
273            let codes_validity = if children.len() == 2 {
274                Validity::from(dtype.nullability())
275            } else if children.len() == 3 {
276                let validity = children.get(2, &Validity::DTYPE, len)?;
277                Validity::Array(validity)
278            } else {
279                vortex_bail!("Expected 0 or 1 child, got {}", children.len());
280            };
281
282            let codes = VarBinArray::try_new(
283                codes_offsets,
284                codes_buffer,
285                DType::Binary(dtype.nullability()),
286                codes_validity,
287            )?;
288
289            return FSSTArray::try_new(
290                dtype.clone(),
291                symbols,
292                symbol_lengths,
293                codes,
294                uncompressed_lengths,
295            );
296        }
297
298        vortex_bail!(
299            "InvalidArgument: Expected 2 or 3 buffers, got {}",
300            buffers.len()
301        );
302    }
303
304    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
305        vortex_ensure!(
306            children.len() == 2,
307            "FSSTArray expects 2 children, got {}",
308            children.len()
309        );
310
311        let mut children_iter = children.into_iter();
312        let codes = children_iter
313            .next()
314            .ok_or_else(|| vortex_err!("FSSTArray with_children missing codes"))?;
315
316        let codes = codes
317            .as_opt::<VarBinVTable>()
318            .ok_or_else(|| {
319                vortex_err!(
320                    "Expected VarBinArray for codes, got {}",
321                    codes.encoding_id()
322                )
323            })?
324            .clone();
325        let uncompressed_lengths = children_iter
326            .next()
327            .ok_or_else(|| vortex_err!("FSSTArray with_children missing uncompressed_lengths"))?;
328
329        array.codes = codes;
330        array.uncompressed_lengths = uncompressed_lengths;
331
332        Ok(())
333    }
334
335    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
336        canonicalize_fsst(array, ctx)
337    }
338
339    fn execute_parent(
340        array: &Self::Array,
341        parent: &ArrayRef,
342        child_idx: usize,
343        ctx: &mut ExecutionCtx,
344    ) -> VortexResult<Option<ArrayRef>> {
345        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
346    }
347
348    fn reduce_parent(
349        array: &Self::Array,
350        parent: &ArrayRef,
351        child_idx: usize,
352    ) -> VortexResult<Option<ArrayRef>> {
353        RULES.evaluate(array, parent, child_idx)
354    }
355}
356
/// An array whose values are stored FSST-compressed: a symbol table plus, per element, a
/// sequence of codes referring into that table.
#[derive(Clone)]
pub struct FSSTArray {
    /// The dtype reported for this array.
    dtype: DType,
    /// The symbol table.
    symbols: Buffer<Symbol>,
    /// The length of each symbol; has one entry per symbol.
    symbol_lengths: Buffer<u8>,
    /// The compressed codes, one binary value per element.
    codes: VarBinArray,
    /// NOTE(ngates): this === codes, but is stored as an ArrayRef so we can return &ArrayRef!
    codes_array: ArrayRef,
    /// Lengths of the original values before compression, can be compressed.
    uncompressed_lengths: ArrayRef,
    /// Statistics for this array; starts out as the default (empty) set.
    stats_set: ArrayStats,

    /// Memoized compressor used for push-down of compute by compressing the RHS.
    compressor: Arc<LazyLock<Compressor, Box<dyn Fn() -> Compressor + Send>>>,
}
372
373impl Debug for FSSTArray {
374    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
375        f.debug_struct("FSSTArray")
376            .field("dtype", &self.dtype)
377            .field("symbols", &self.symbols)
378            .field("symbol_lengths", &self.symbol_lengths)
379            .field("codes", &self.codes)
380            .field("uncompressed_lengths", &self.uncompressed_lengths)
381            .finish()
382    }
383}
384
/// Marker type carrying the vtable implementation for [`FSSTArray`].
#[derive(Debug)]
pub struct FSSTVTable;
387
impl FSSTVTable {
    /// Identifier of the FSST encoding, returned by the vtable's `id`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.fsst");
}
391
392impl FSSTArray {
393    /// Build an FSST array from a set of `symbols` and `codes`.
394    ///
395    /// Symbols are 8-bytes and can represent short strings, each of which is assigned
396    /// a code.
397    ///
398    /// The `codes` array is a Binary array where each binary datum is a sequence of 8-bit codes.
399    /// Each code corresponds either to a symbol, or to the "escape code",
400    /// which tells the decoder to emit the following byte without doing a table lookup.
401    pub fn try_new(
402        dtype: DType,
403        symbols: Buffer<Symbol>,
404        symbol_lengths: Buffer<u8>,
405        codes: VarBinArray,
406        uncompressed_lengths: ArrayRef,
407    ) -> VortexResult<Self> {
408        // Check: symbols must not have length > MAX_CODE
409        if symbols.len() > 255 {
410            vortex_bail!(InvalidArgument: "symbols array must have length <= 255");
411        }
412        if symbols.len() != symbol_lengths.len() {
413            vortex_bail!(InvalidArgument: "symbols and symbol_lengths arrays must have same length");
414        }
415
416        if uncompressed_lengths.len() != codes.len() {
417            vortex_bail!(InvalidArgument: "uncompressed_lengths must be same len as codes");
418        }
419
420        if !uncompressed_lengths.dtype().is_int() || uncompressed_lengths.dtype().is_nullable() {
421            vortex_bail!(InvalidArgument: "uncompressed_lengths must have integer type and cannot be nullable, found {}", uncompressed_lengths.dtype());
422        }
423
424        // Check: strings must be a Binary array.
425        if !matches!(codes.dtype(), DType::Binary(_)) {
426            vortex_bail!(InvalidArgument: "codes array must be DType::Binary type");
427        }
428
429        // SAFETY: all components validated above
430        unsafe {
431            Ok(Self::new_unchecked(
432                dtype,
433                symbols,
434                symbol_lengths,
435                codes,
436                uncompressed_lengths,
437            ))
438        }
439    }
440
441    pub(crate) unsafe fn new_unchecked(
442        dtype: DType,
443        symbols: Buffer<Symbol>,
444        symbol_lengths: Buffer<u8>,
445        codes: VarBinArray,
446        uncompressed_lengths: ArrayRef,
447    ) -> Self {
448        let symbols2 = symbols.clone();
449        let symbol_lengths2 = symbol_lengths.clone();
450        let compressor = Arc::new(LazyLock::new(Box::new(move || {
451            Compressor::rebuild_from(symbols2.as_slice(), symbol_lengths2.as_slice())
452        })
453            as Box<dyn Fn() -> Compressor + Send>));
454        let codes_array = codes.to_array();
455
456        Self {
457            dtype,
458            symbols,
459            symbol_lengths,
460            codes,
461            codes_array,
462            uncompressed_lengths,
463            stats_set: Default::default(),
464            compressor,
465        }
466    }
467
468    /// Access the symbol table array
469    pub fn symbols(&self) -> &Buffer<Symbol> {
470        &self.symbols
471    }
472
473    /// Access the symbol table array
474    pub fn symbol_lengths(&self) -> &Buffer<u8> {
475        &self.symbol_lengths
476    }
477
478    /// Access the codes array
479    pub fn codes(&self) -> &VarBinArray {
480        &self.codes
481    }
482
483    /// Get the DType of the codes array
484    #[inline]
485    pub fn codes_dtype(&self) -> &DType {
486        self.codes.dtype()
487    }
488
489    /// Get the uncompressed length for each element in the array.
490    pub fn uncompressed_lengths(&self) -> &ArrayRef {
491        &self.uncompressed_lengths
492    }
493
494    /// Get the DType of the uncompressed lengths array
495    #[inline]
496    pub fn uncompressed_lengths_dtype(&self) -> &DType {
497        self.uncompressed_lengths.dtype()
498    }
499
500    /// Build a [`Decompressor`][fsst::Decompressor] that can be used to decompress values from
501    /// this array.
502    pub fn decompressor(&self) -> Decompressor<'_> {
503        Decompressor::new(self.symbols().as_slice(), self.symbol_lengths().as_slice())
504    }
505
506    /// Retrieves the FSST compressor.
507    pub fn compressor(&self) -> &Compressor {
508        self.compressor.as_ref()
509    }
510}
511
512impl ValidityChild<FSSTVTable> for FSSTVTable {
513    fn validity_child(array: &FSSTArray) -> &ArrayRef {
514        &array.codes_array
515    }
516}
517
#[cfg(test)]
mod test {
    use fsst::Compressor;
    use fsst::Symbol;
    use vortex_array::Array;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::ProstMetadata;
    use vortex_array::VortexSessionExecute;
    use vortex_array::accessor::ArrayAccessor;
    use vortex_array::arrays::VarBinViewArray;
    use vortex_array::buffer::BufferHandle;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::test_harness::check_metadata;
    use vortex_array::vtable::VTable;
    use vortex_buffer::Buffer;
    use vortex_error::VortexError;

    use crate::FSSTVTable;
    use crate::array::FSSTMetadata;
    use crate::fsst_compress_iter;

    /// Runs the shared `check_metadata` harness on a representative [`FSSTMetadata`] value
    /// under the name `fsst.metadata`.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_fsst_metadata() {
        let metadata = FSSTMetadata {
            uncompressed_lengths_ptype: PType::U64 as i32,
            codes_offsets_ptype: PType::I32 as i32,
        };
        check_metadata("fsst.metadata", ProstMetadata(metadata));
    }

    /// The original FSST array stored codes as a VarBinArray child and required that the child
    /// have this encoding. Vortex forbids this kind of introspection, therefore we had to fix
    /// the array to store the compressed offsets and compressed data buffer separately, and only
    /// use VarBinArray to delegate behavior.
    ///
    /// This test manually constructs an old-style FSST array and ensures that it can still be
    /// deserialized.
    #[test]
    fn test_back_compat() {
        let utf8 = DType::Utf8(Nullability::NonNullable);

        let symbols = Buffer::<Symbol>::copy_from([
            Symbol::from_slice(b"abc00000"),
            Symbol::from_slice(b"defghijk"),
        ]);
        let symbol_lengths = Buffer::<u8>::copy_from([3, 8]);

        // Compress two values with a compressor rebuilt from the hand-written symbol table.
        let compressor = Compressor::rebuild_from(symbols.as_slice(), symbol_lengths.as_slice());
        let fsst_array = fsst_compress_iter(
            [Some(b"abcabcab".as_ref()), Some(b"defghijk".as_ref())].into_iter(),
            2,
            utf8.clone(),
            &compressor,
        );

        // The legacy layout had two buffers:
        // 1. The 8 byte symbols
        // 2. The symbol lengths as u8.
        let legacy_buffers = [
            BufferHandle::new_host(symbols.into_byte_buffer()),
            BufferHandle::new_host(symbol_lengths.into_byte_buffer()),
        ];

        // ...and two children:
        // 1. The compressed codes, stored as a VarBinArray.
        // 2. The uncompressed lengths, stored as a Primitive array.
        let legacy_codes = fsst_array.codes().clone();
        let children = vec![
            legacy_codes.into_array(),
            fsst_array.uncompressed_lengths().clone(),
        ];

        let lengths_ptype: i32 = fsst_array
            .uncompressed_lengths()
            .dtype()
            .as_ptype()
            .into();
        let legacy_metadata = ProstMetadata(FSSTMetadata {
            uncompressed_lengths_ptype: lengths_ptype,
            // Legacy array did not store this field, use Protobuf default of 0.
            codes_offsets_ptype: 0,
        });

        let fsst = FSSTVTable::build(
            &utf8,
            2,
            &legacy_metadata,
            &legacy_buffers,
            &children.as_slice(),
        )
        .unwrap();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let decompressed = fsst
            .into_array()
            .execute::<VarBinViewArray>(&mut ctx)
            .unwrap();

        // The rebuilt array must decode back to the two original values, in order.
        let expected: [&[u8]; 2] = [b"abcabcab", b"defghijk"];
        decompressed
            .with_iterator(|values| {
                for want in expected {
                    assert_eq!(values.next().unwrap(), Some(want));
                }
                Ok::<_, VortexError>(())
            })
            .unwrap()
    }
}