Skip to main content

vortex_fsst/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8use std::sync::LazyLock;
9
10use fsst::Compressor;
11use fsst::Decompressor;
12use fsst::Symbol;
13use vortex_array::ArrayEq;
14use vortex_array::ArrayHash;
15use vortex_array::ArrayRef;
16use vortex_array::Canonical;
17use vortex_array::DeserializeMetadata;
18use vortex_array::DynArray;
19use vortex_array::ExecutionCtx;
20use vortex_array::ExecutionResult;
21use vortex_array::IntoArray;
22use vortex_array::Precision;
23use vortex_array::ProstMetadata;
24use vortex_array::SerializeMetadata;
25use vortex_array::arrays::VarBin;
26use vortex_array::arrays::VarBinArray;
27use vortex_array::buffer::BufferHandle;
28use vortex_array::builders::ArrayBuilder;
29use vortex_array::builders::VarBinViewBuilder;
30use vortex_array::dtype::DType;
31use vortex_array::dtype::Nullability;
32use vortex_array::dtype::PType;
33use vortex_array::serde::ArrayChildren;
34use vortex_array::stats::ArrayStats;
35use vortex_array::stats::StatsSetRef;
36use vortex_array::validity::Validity;
37use vortex_array::vtable;
38use vortex_array::vtable::ArrayId;
39use vortex_array::vtable::VTable;
40use vortex_array::vtable::ValidityChild;
41use vortex_array::vtable::ValidityHelper;
42use vortex_array::vtable::ValidityVTableFromChild;
43use vortex_array::vtable::validity_nchildren;
44use vortex_array::vtable::validity_to_child;
45use vortex_buffer::Buffer;
46use vortex_buffer::ByteBuffer;
47use vortex_error::VortexResult;
48use vortex_error::vortex_bail;
49use vortex_error::vortex_ensure;
50use vortex_error::vortex_err;
51use vortex_error::vortex_panic;
52use vortex_session::VortexSession;
53
54use crate::canonical::canonicalize_fsst;
55use crate::canonical::fsst_decode_views;
56use crate::kernel::PARENT_KERNELS;
57use crate::rules::RULES;
58
// Invokes the `vtable!` macro imported from `vortex_array` for the `FSST` encoding.
// NOTE(review): presumably this generates the glue between `FSST` and `FSSTArray`
// (e.g. the `as_ref()` / `into_array()` conversions used below) — confirm against the
// macro definition.
vtable!(FSST);
60
61#[derive(Clone, prost::Message)]
62pub struct FSSTMetadata {
63    #[prost(enumeration = "PType", tag = "1")]
64    uncompressed_lengths_ptype: i32,
65
66    #[prost(enumeration = "PType", tag = "2")]
67    codes_offsets_ptype: i32,
68}
69
70impl FSSTMetadata {
71    pub fn get_uncompressed_lengths_ptype(&self) -> VortexResult<PType> {
72        PType::try_from(self.uncompressed_lengths_ptype)
73            .map_err(|_| vortex_err!("Invalid PType {}", self.uncompressed_lengths_ptype))
74    }
75}
76
77impl VTable for FSST {
78    type Array = FSSTArray;
79
80    type Metadata = ProstMetadata<FSSTMetadata>;
81    type OperationsVTable = Self;
82    type ValidityVTable = ValidityVTableFromChild;
83
84    fn vtable(_array: &Self::Array) -> &Self {
85        &FSST
86    }
87
88    fn id(&self) -> ArrayId {
89        Self::ID
90    }
91
92    fn len(array: &FSSTArray) -> usize {
93        array.codes().len()
94    }
95
96    fn dtype(array: &FSSTArray) -> &DType {
97        &array.dtype
98    }
99
100    fn stats(array: &FSSTArray) -> StatsSetRef<'_> {
101        array.stats_set.to_ref(array.as_ref())
102    }
103
104    fn array_hash<H: std::hash::Hasher>(array: &FSSTArray, state: &mut H, precision: Precision) {
105        array.dtype.hash(state);
106        array.symbols.array_hash(state, precision);
107        array.symbol_lengths.array_hash(state, precision);
108        array.codes.as_ref().array_hash(state, precision);
109        array.uncompressed_lengths.array_hash(state, precision);
110    }
111
112    fn array_eq(array: &FSSTArray, other: &FSSTArray, precision: Precision) -> bool {
113        array.dtype == other.dtype
114            && array.symbols.array_eq(&other.symbols, precision)
115            && array
116                .symbol_lengths
117                .array_eq(&other.symbol_lengths, precision)
118            && array
119                .codes
120                .as_ref()
121                .array_eq(other.codes.as_ref(), precision)
122            && array
123                .uncompressed_lengths
124                .array_eq(&other.uncompressed_lengths, precision)
125    }
126
127    fn nbuffers(_array: &FSSTArray) -> usize {
128        3
129    }
130
131    fn buffer(array: &FSSTArray, idx: usize) -> BufferHandle {
132        match idx {
133            0 => BufferHandle::new_host(array.symbols().clone().into_byte_buffer()),
134            1 => BufferHandle::new_host(array.symbol_lengths().clone().into_byte_buffer()),
135            2 => array.codes.bytes_handle().clone(),
136            _ => vortex_panic!("FSSTArray buffer index {idx} out of bounds"),
137        }
138    }
139
140    fn buffer_name(_array: &FSSTArray, idx: usize) -> Option<String> {
141        match idx {
142            0 => Some("symbols".to_string()),
143            1 => Some("symbol_lengths".to_string()),
144            2 => Some("compressed_codes".to_string()),
145            _ => vortex_panic!("FSSTArray buffer_name index {idx} out of bounds"),
146        }
147    }
148
149    fn nchildren(array: &FSSTArray) -> usize {
150        2 + validity_nchildren(array.codes.validity())
151    }
152
153    fn child(array: &FSSTArray, idx: usize) -> ArrayRef {
154        match idx {
155            0 => array.uncompressed_lengths().clone(),
156            1 => array.codes.offsets().clone(),
157            2 => validity_to_child(array.codes.validity(), array.codes.len())
158                .unwrap_or_else(|| vortex_panic!("FSSTArray child index {idx} out of bounds")),
159            _ => vortex_panic!("FSSTArray child index {idx} out of bounds"),
160        }
161    }
162
163    fn child_name(_array: &FSSTArray, idx: usize) -> String {
164        match idx {
165            0 => "uncompressed_lengths".to_string(),
166            1 => "codes_offsets".to_string(),
167            2 => "validity".to_string(),
168            _ => vortex_panic!("FSSTArray child_name index {idx} out of bounds"),
169        }
170    }
171
172    fn metadata(array: &FSSTArray) -> VortexResult<Self::Metadata> {
173        Ok(ProstMetadata(FSSTMetadata {
174            uncompressed_lengths_ptype: array.uncompressed_lengths().dtype().as_ptype().into(),
175            codes_offsets_ptype: array.codes.offsets().dtype().as_ptype().into(),
176        }))
177    }
178
179    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
180        Ok(Some(metadata.serialize()))
181    }
182
183    fn deserialize(
184        bytes: &[u8],
185        _dtype: &DType,
186        _len: usize,
187        _buffers: &[BufferHandle],
188        _session: &VortexSession,
189    ) -> VortexResult<Self::Metadata> {
190        Ok(ProstMetadata(
191            <ProstMetadata<FSSTMetadata> as DeserializeMetadata>::deserialize(bytes)?,
192        ))
193    }
194
195    fn append_to_builder(
196        array: &FSSTArray,
197        builder: &mut dyn ArrayBuilder,
198        ctx: &mut ExecutionCtx,
199    ) -> VortexResult<()> {
200        let Some(builder) = builder.as_any_mut().downcast_mut::<VarBinViewBuilder>() else {
201            builder.extend_from_array(
202                &array
203                    .clone()
204                    .into_array()
205                    .execute::<Canonical>(ctx)?
206                    .into_array(),
207            );
208            return Ok(());
209        };
210
211        // Decompress the whole block of data into a new buffer, and create some views
212        // from it instead.
213        let (buffers, views) = fsst_decode_views(array, builder.completed_block_count(), ctx)?;
214
215        builder.push_buffer_and_adjusted_views(&buffers, &views, array.validity_mask()?);
216        Ok(())
217    }
218
219    fn build(
220        dtype: &DType,
221        len: usize,
222        metadata: &Self::Metadata,
223        buffers: &[BufferHandle],
224        children: &dyn ArrayChildren,
225    ) -> VortexResult<FSSTArray> {
226        let symbols = Buffer::<Symbol>::from_byte_buffer(buffers[0].clone().try_to_host_sync()?);
227        let symbol_lengths = Buffer::<u8>::from_byte_buffer(buffers[1].clone().try_to_host_sync()?);
228
229        // Check for the legacy deserialization path.
230        if buffers.len() == 2 {
231            if children.len() != 2 {
232                vortex_bail!(InvalidArgument: "Expected 2 children, got {}", children.len());
233            }
234            let codes = children.get(0, &DType::Binary(dtype.nullability()), len)?;
235            let codes = codes
236                .as_opt::<VarBin>()
237                .ok_or_else(|| {
238                    vortex_err!(
239                        "Expected VarBinArray for codes, got {}",
240                        codes.encoding_id()
241                    )
242                })?
243                .clone();
244            let uncompressed_lengths = children.get(
245                1,
246                &DType::Primitive(
247                    metadata.0.get_uncompressed_lengths_ptype()?,
248                    Nullability::NonNullable,
249                ),
250                len,
251            )?;
252
253            return FSSTArray::try_new(
254                dtype.clone(),
255                symbols,
256                symbol_lengths,
257                codes,
258                uncompressed_lengths,
259            );
260        }
261
262        // Check for the current deserialization path.
263        if buffers.len() == 3 {
264            let uncompressed_lengths = children.get(
265                0,
266                &DType::Primitive(
267                    metadata.0.get_uncompressed_lengths_ptype()?,
268                    Nullability::NonNullable,
269                ),
270                len,
271            )?;
272
273            let codes_buffer = ByteBuffer::from_byte_buffer(buffers[2].clone().try_to_host_sync()?);
274            let codes_offsets = children.get(
275                1,
276                &DType::Primitive(
277                    PType::try_from(metadata.codes_offsets_ptype)?,
278                    Nullability::NonNullable,
279                ),
280                // VarBin offsets are len + 1
281                len + 1,
282            )?;
283
284            let codes_validity = if children.len() == 2 {
285                Validity::from(dtype.nullability())
286            } else if children.len() == 3 {
287                let validity = children.get(2, &Validity::DTYPE, len)?;
288                Validity::Array(validity)
289            } else {
290                vortex_bail!("Expected 0 or 1 child, got {}", children.len());
291            };
292
293            let codes = VarBinArray::try_new(
294                codes_offsets,
295                codes_buffer,
296                DType::Binary(dtype.nullability()),
297                codes_validity,
298            )?;
299
300            return FSSTArray::try_new(
301                dtype.clone(),
302                symbols,
303                symbol_lengths,
304                codes,
305                uncompressed_lengths,
306            );
307        }
308
309        vortex_bail!(
310            "InvalidArgument: Expected 2 or 3 buffers, got {}",
311            buffers.len()
312        );
313    }
314
315    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
316        vortex_ensure!(
317            children.len() == 2,
318            "FSSTArray expects 2 children, got {}",
319            children.len()
320        );
321
322        let mut children_iter = children.into_iter();
323        let codes = children_iter
324            .next()
325            .ok_or_else(|| vortex_err!("FSSTArray with_children missing codes"))?;
326
327        let codes = codes
328            .as_opt::<VarBin>()
329            .ok_or_else(|| {
330                vortex_err!(
331                    "Expected VarBinArray for codes, got {}",
332                    codes.encoding_id()
333                )
334            })?
335            .clone();
336        let uncompressed_lengths = children_iter
337            .next()
338            .ok_or_else(|| vortex_err!("FSSTArray with_children missing uncompressed_lengths"))?;
339
340        array.codes = codes;
341        array.uncompressed_lengths = uncompressed_lengths;
342
343        Ok(())
344    }
345
346    fn execute(array: Arc<Self::Array>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
347        canonicalize_fsst(&array, ctx).map(ExecutionResult::done)
348    }
349
350    fn execute_parent(
351        array: &Self::Array,
352        parent: &ArrayRef,
353        child_idx: usize,
354        ctx: &mut ExecutionCtx,
355    ) -> VortexResult<Option<ArrayRef>> {
356        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
357    }
358
359    fn reduce_parent(
360        array: &Self::Array,
361        parent: &ArrayRef,
362        child_idx: usize,
363    ) -> VortexResult<Option<ArrayRef>> {
364        RULES.evaluate(array, parent, child_idx)
365    }
366}
367
368#[derive(Clone)]
369pub struct FSSTArray {
370    dtype: DType,
371    symbols: Buffer<Symbol>,
372    symbol_lengths: Buffer<u8>,
373    codes: VarBinArray,
374    /// NOTE(ngates): this === codes, but is stored as an ArrayRef so we can return &ArrayRef!
375    codes_array: ArrayRef,
376    /// Lengths of the original values before compression, can be compressed.
377    uncompressed_lengths: ArrayRef,
378    stats_set: ArrayStats,
379
380    /// Memoized compressor used for push-down of compute by compressing the RHS.
381    compressor: Arc<LazyLock<Compressor, Box<dyn Fn() -> Compressor + Send>>>,
382}
383
384impl Debug for FSSTArray {
385    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
386        f.debug_struct("FSSTArray")
387            .field("dtype", &self.dtype)
388            .field("symbols", &self.symbols)
389            .field("symbol_lengths", &self.symbol_lengths)
390            .field("codes", &self.codes)
391            .field("uncompressed_lengths", &self.uncompressed_lengths)
392            .finish()
393    }
394}
395
/// Marker type for the FSST encoding; implements the encoding's `VTable`.
#[derive(Clone, Debug)]
pub struct FSST;
398
399impl FSST {
400    pub const ID: ArrayId = ArrayId::new_ref("vortex.fsst");
401}
402
403impl FSSTArray {
404    /// Build an FSST array from a set of `symbols` and `codes`.
405    ///
406    /// Symbols are 8-bytes and can represent short strings, each of which is assigned
407    /// a code.
408    ///
409    /// The `codes` array is a Binary array where each binary datum is a sequence of 8-bit codes.
410    /// Each code corresponds either to a symbol, or to the "escape code",
411    /// which tells the decoder to emit the following byte without doing a table lookup.
412    pub fn try_new(
413        dtype: DType,
414        symbols: Buffer<Symbol>,
415        symbol_lengths: Buffer<u8>,
416        codes: VarBinArray,
417        uncompressed_lengths: ArrayRef,
418    ) -> VortexResult<Self> {
419        // Check: symbols must not have length > MAX_CODE
420        if symbols.len() > 255 {
421            vortex_bail!(InvalidArgument: "symbols array must have length <= 255");
422        }
423        if symbols.len() != symbol_lengths.len() {
424            vortex_bail!(InvalidArgument: "symbols and symbol_lengths arrays must have same length");
425        }
426
427        if uncompressed_lengths.len() != codes.len() {
428            vortex_bail!(InvalidArgument: "uncompressed_lengths must be same len as codes");
429        }
430
431        if !uncompressed_lengths.dtype().is_int() || uncompressed_lengths.dtype().is_nullable() {
432            vortex_bail!(InvalidArgument: "uncompressed_lengths must have integer type and cannot be nullable, found {}", uncompressed_lengths.dtype());
433        }
434
435        // Check: strings must be a Binary array.
436        if !matches!(codes.dtype(), DType::Binary(_)) {
437            vortex_bail!(InvalidArgument: "codes array must be DType::Binary type");
438        }
439
440        // SAFETY: all components validated above
441        unsafe {
442            Ok(Self::new_unchecked(
443                dtype,
444                symbols,
445                symbol_lengths,
446                codes,
447                uncompressed_lengths,
448            ))
449        }
450    }
451
452    pub(crate) unsafe fn new_unchecked(
453        dtype: DType,
454        symbols: Buffer<Symbol>,
455        symbol_lengths: Buffer<u8>,
456        codes: VarBinArray,
457        uncompressed_lengths: ArrayRef,
458    ) -> Self {
459        let symbols2 = symbols.clone();
460        let symbol_lengths2 = symbol_lengths.clone();
461        let compressor = Arc::new(LazyLock::new(Box::new(move || {
462            Compressor::rebuild_from(symbols2.as_slice(), symbol_lengths2.as_slice())
463        })
464            as Box<dyn Fn() -> Compressor + Send>));
465        let codes_array = codes.clone().into_array();
466
467        Self {
468            dtype,
469            symbols,
470            symbol_lengths,
471            codes,
472            codes_array,
473            uncompressed_lengths,
474            stats_set: Default::default(),
475            compressor,
476        }
477    }
478
479    /// Access the symbol table array
480    pub fn symbols(&self) -> &Buffer<Symbol> {
481        &self.symbols
482    }
483
484    /// Access the symbol table array
485    pub fn symbol_lengths(&self) -> &Buffer<u8> {
486        &self.symbol_lengths
487    }
488
489    /// Access the codes array
490    pub fn codes(&self) -> &VarBinArray {
491        &self.codes
492    }
493
494    /// Get the DType of the codes array
495    #[inline]
496    pub fn codes_dtype(&self) -> &DType {
497        self.codes.dtype()
498    }
499
500    /// Get the uncompressed length for each element in the array.
501    pub fn uncompressed_lengths(&self) -> &ArrayRef {
502        &self.uncompressed_lengths
503    }
504
505    /// Get the DType of the uncompressed lengths array
506    #[inline]
507    pub fn uncompressed_lengths_dtype(&self) -> &DType {
508        self.uncompressed_lengths.dtype()
509    }
510
511    /// Build a [`Decompressor`][fsst::Decompressor] that can be used to decompress values from
512    /// this array.
513    pub fn decompressor(&self) -> Decompressor<'_> {
514        Decompressor::new(self.symbols().as_slice(), self.symbol_lengths().as_slice())
515    }
516
517    /// Retrieves the FSST compressor.
518    pub fn compressor(&self) -> &Compressor {
519        self.compressor.as_ref()
520    }
521}
522
523impl ValidityChild<FSST> for FSST {
524    fn validity_child(array: &FSSTArray) -> &ArrayRef {
525        &array.codes_array
526    }
527}
528
#[cfg(test)]
mod test {
    use fsst::Compressor;
    use fsst::Symbol;
    use vortex_array::DynArray;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::ProstMetadata;
    use vortex_array::VortexSessionExecute;
    use vortex_array::accessor::ArrayAccessor;
    use vortex_array::arrays::VarBinViewArray;
    use vortex_array::buffer::BufferHandle;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::test_harness::check_metadata;
    use vortex_array::vtable::VTable;
    use vortex_buffer::Buffer;
    use vortex_error::VortexError;

    use crate::FSST;
    use crate::array::FSSTMetadata;
    use crate::fsst_compress_iter;

    /// Checks the serialized form of the metadata against the `fsst.metadata` fixture.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_fsst_metadata() {
        let metadata = ProstMetadata(FSSTMetadata {
            uncompressed_lengths_ptype: PType::U64 as i32,
            codes_offsets_ptype: PType::I32 as i32,
        });
        check_metadata("fsst.metadata", metadata);
    }

    /// The original FSST array stored codes as a VarBinArray child and required that the child
    /// have this encoding. Vortex forbids this kind of introspection, therefore we had to fix
    /// the array to store the compressed offsets and compressed data buffer separately, and only
    /// use VarBinArray to delegate behavior.
    ///
    /// This test manually constructs an old-style FSST array and ensures that it can still be
    /// deserialized.
    #[test]
    fn test_back_compat() {
        let symbols = Buffer::<Symbol>::copy_from([
            Symbol::from_slice(b"abc00000"),
            Symbol::from_slice(b"defghijk"),
        ]);
        let symbol_lengths = Buffer::<u8>::copy_from([3, 8]);

        let compressor = Compressor::rebuild_from(symbols.as_slice(), symbol_lengths.as_slice());
        let utf8 = DType::Utf8(Nullability::NonNullable);
        let inputs = [Some(b"abcabcab".as_ref()), Some(b"defghijk".as_ref())];
        let compressed = fsst_compress_iter(inputs.into_iter(), 2, utf8, &compressor);

        let legacy_codes = compressed.codes().clone();

        // The legacy layout carried two buffers:
        // 1. The 8 byte symbols
        // 2. The symbol lengths as u8.
        let legacy_buffers = [
            BufferHandle::new_host(symbols.into_byte_buffer()),
            BufferHandle::new_host(symbol_lengths.into_byte_buffer()),
        ];

        // The legacy layout carried two children:
        // 1. The compressed codes, stored as a VarBinArray.
        // 2. The uncompressed lengths, stored as a Primitive array.
        let children = vec![
            legacy_codes.into_array(),
            compressed.uncompressed_lengths().clone(),
        ];

        let legacy_metadata = ProstMetadata(FSSTMetadata {
            uncompressed_lengths_ptype: compressed
                .uncompressed_lengths()
                .dtype()
                .as_ptype()
                .into(),
            // Legacy array did not store this field, use Protobuf default of 0.
            codes_offsets_ptype: 0,
        });

        let rebuilt = FSST::build(
            &DType::Utf8(Nullability::NonNullable),
            2,
            &legacy_metadata,
            &legacy_buffers,
            &children.as_slice(),
        )
        .unwrap();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let decompressed = rebuilt
            .into_array()
            .execute::<VarBinViewArray>(&mut ctx)
            .unwrap();
        decompressed
            .with_iterator(|values| {
                assert_eq!(values.next().unwrap(), Some(b"abcabcab".as_ref()));
                assert_eq!(values.next().unwrap(), Some(b"defghijk".as_ref()));
                Ok::<_, VortexError>(())
            })
            .unwrap()
    }
}