Skip to main content

vortex_fsst/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8use std::sync::LazyLock;
9
10use fsst::Compressor;
11use fsst::Decompressor;
12use fsst::Symbol;
13use vortex_array::ArrayEq;
14use vortex_array::ArrayHash;
15use vortex_array::ArrayRef;
16use vortex_array::Canonical;
17use vortex_array::DeserializeMetadata;
18use vortex_array::DynArray;
19use vortex_array::ExecutionCtx;
20use vortex_array::ExecutionStep;
21use vortex_array::IntoArray;
22use vortex_array::Precision;
23use vortex_array::ProstMetadata;
24use vortex_array::SerializeMetadata;
25use vortex_array::arrays::VarBinArray;
26use vortex_array::arrays::VarBinVTable;
27use vortex_array::buffer::BufferHandle;
28use vortex_array::builders::ArrayBuilder;
29use vortex_array::builders::VarBinViewBuilder;
30use vortex_array::dtype::DType;
31use vortex_array::dtype::Nullability;
32use vortex_array::dtype::PType;
33use vortex_array::serde::ArrayChildren;
34use vortex_array::stats::ArrayStats;
35use vortex_array::stats::StatsSetRef;
36use vortex_array::validity::Validity;
37use vortex_array::vtable;
38use vortex_array::vtable::ArrayId;
39use vortex_array::vtable::VTable;
40use vortex_array::vtable::ValidityChild;
41use vortex_array::vtable::ValidityHelper;
42use vortex_array::vtable::ValidityVTableFromChild;
43use vortex_array::vtable::validity_nchildren;
44use vortex_array::vtable::validity_to_child;
45use vortex_buffer::Buffer;
46use vortex_buffer::ByteBuffer;
47use vortex_error::VortexResult;
48use vortex_error::vortex_bail;
49use vortex_error::vortex_ensure;
50use vortex_error::vortex_err;
51use vortex_error::vortex_panic;
52use vortex_session::VortexSession;
53
54use crate::canonical::canonicalize_fsst;
55use crate::canonical::fsst_decode_views;
56use crate::kernel::PARENT_KERNELS;
57use crate::rules::RULES;
58
// Invokes the `vortex_array::vtable` macro for the FSST encoding. NOTE(review): presumably this
// generates the glue tying `FSSTVTable` to `FSSTArray` — confirm against the macro definition.
vtable!(FSST);
60
/// Protobuf-serialized metadata for an [`FSSTArray`].
///
/// Records the primitive types of the integer children so that [`FSSTVTable::build`] can request
/// them with the correct dtype when the array is reconstructed.
#[derive(Clone, prost::Message)]
pub struct FSSTMetadata {
    /// `PType` of the `uncompressed_lengths` child, stored as its protobuf enum value.
    #[prost(enumeration = "PType", tag = "1")]
    uncompressed_lengths_ptype: i32,

    /// `PType` of the codes offsets child, stored as its protobuf enum value.
    ///
    /// Legacy metadata did not write this field, so it decodes as the protobuf default `0`;
    /// the legacy (two-buffer) build path does not read it.
    #[prost(enumeration = "PType", tag = "2")]
    codes_offsets_ptype: i32,
}
69
70impl FSSTMetadata {
71    pub fn get_uncompressed_lengths_ptype(&self) -> VortexResult<PType> {
72        PType::try_from(self.uncompressed_lengths_ptype)
73            .map_err(|_| vortex_err!("Invalid PType {}", self.uncompressed_lengths_ptype))
74    }
75}
76
77impl VTable for FSSTVTable {
78    type Array = FSSTArray;
79
80    type Metadata = ProstMetadata<FSSTMetadata>;
81    type OperationsVTable = Self;
82    type ValidityVTable = ValidityVTableFromChild;
83
84    fn id(_array: &Self::Array) -> ArrayId {
85        Self::ID
86    }
87
88    fn len(array: &FSSTArray) -> usize {
89        array.codes().len()
90    }
91
92    fn dtype(array: &FSSTArray) -> &DType {
93        &array.dtype
94    }
95
96    fn stats(array: &FSSTArray) -> StatsSetRef<'_> {
97        array.stats_set.to_ref(array.as_ref())
98    }
99
100    fn array_hash<H: std::hash::Hasher>(array: &FSSTArray, state: &mut H, precision: Precision) {
101        array.dtype.hash(state);
102        array.symbols.array_hash(state, precision);
103        array.symbol_lengths.array_hash(state, precision);
104        array.codes.as_ref().array_hash(state, precision);
105        array.uncompressed_lengths.array_hash(state, precision);
106    }
107
108    fn array_eq(array: &FSSTArray, other: &FSSTArray, precision: Precision) -> bool {
109        array.dtype == other.dtype
110            && array.symbols.array_eq(&other.symbols, precision)
111            && array
112                .symbol_lengths
113                .array_eq(&other.symbol_lengths, precision)
114            && array
115                .codes
116                .as_ref()
117                .array_eq(other.codes.as_ref(), precision)
118            && array
119                .uncompressed_lengths
120                .array_eq(&other.uncompressed_lengths, precision)
121    }
122
123    fn nbuffers(_array: &FSSTArray) -> usize {
124        3
125    }
126
127    fn buffer(array: &FSSTArray, idx: usize) -> BufferHandle {
128        match idx {
129            0 => BufferHandle::new_host(array.symbols().clone().into_byte_buffer()),
130            1 => BufferHandle::new_host(array.symbol_lengths().clone().into_byte_buffer()),
131            2 => array.codes.bytes_handle().clone(),
132            _ => vortex_panic!("FSSTArray buffer index {idx} out of bounds"),
133        }
134    }
135
136    fn buffer_name(_array: &FSSTArray, idx: usize) -> Option<String> {
137        match idx {
138            0 => Some("symbols".to_string()),
139            1 => Some("symbol_lengths".to_string()),
140            2 => Some("compressed_codes".to_string()),
141            _ => vortex_panic!("FSSTArray buffer_name index {idx} out of bounds"),
142        }
143    }
144
145    fn nchildren(array: &FSSTArray) -> usize {
146        2 + validity_nchildren(array.codes.validity())
147    }
148
149    fn child(array: &FSSTArray, idx: usize) -> ArrayRef {
150        match idx {
151            0 => array.uncompressed_lengths().clone(),
152            1 => array.codes.offsets().clone(),
153            2 => validity_to_child(array.codes.validity(), array.codes.len())
154                .unwrap_or_else(|| vortex_panic!("FSSTArray child index {idx} out of bounds")),
155            _ => vortex_panic!("FSSTArray child index {idx} out of bounds"),
156        }
157    }
158
159    fn child_name(_array: &FSSTArray, idx: usize) -> String {
160        match idx {
161            0 => "uncompressed_lengths".to_string(),
162            1 => "codes_offsets".to_string(),
163            2 => "validity".to_string(),
164            _ => vortex_panic!("FSSTArray child_name index {idx} out of bounds"),
165        }
166    }
167
168    fn metadata(array: &FSSTArray) -> VortexResult<Self::Metadata> {
169        Ok(ProstMetadata(FSSTMetadata {
170            uncompressed_lengths_ptype: array.uncompressed_lengths().dtype().as_ptype().into(),
171            codes_offsets_ptype: array.codes.offsets().dtype().as_ptype().into(),
172        }))
173    }
174
175    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
176        Ok(Some(metadata.serialize()))
177    }
178
179    fn deserialize(
180        bytes: &[u8],
181        _dtype: &DType,
182        _len: usize,
183        _buffers: &[BufferHandle],
184        _session: &VortexSession,
185    ) -> VortexResult<Self::Metadata> {
186        Ok(ProstMetadata(
187            <ProstMetadata<FSSTMetadata> as DeserializeMetadata>::deserialize(bytes)?,
188        ))
189    }
190
191    fn append_to_builder(
192        array: &FSSTArray,
193        builder: &mut dyn ArrayBuilder,
194        ctx: &mut ExecutionCtx,
195    ) -> VortexResult<()> {
196        let Some(builder) = builder.as_any_mut().downcast_mut::<VarBinViewBuilder>() else {
197            builder.extend_from_array(
198                &array
199                    .clone()
200                    .into_array()
201                    .execute::<Canonical>(ctx)?
202                    .into_array(),
203            );
204            return Ok(());
205        };
206
207        // Decompress the whole block of data into a new buffer, and create some views
208        // from it instead.
209        let (buffers, views) = fsst_decode_views(array, builder.completed_block_count(), ctx)?;
210
211        builder.push_buffer_and_adjusted_views(&buffers, &views, array.validity_mask()?);
212        Ok(())
213    }
214
215    fn build(
216        dtype: &DType,
217        len: usize,
218        metadata: &Self::Metadata,
219        buffers: &[BufferHandle],
220        children: &dyn ArrayChildren,
221    ) -> VortexResult<FSSTArray> {
222        let symbols = Buffer::<Symbol>::from_byte_buffer(buffers[0].clone().try_to_host_sync()?);
223        let symbol_lengths = Buffer::<u8>::from_byte_buffer(buffers[1].clone().try_to_host_sync()?);
224
225        // Check for the legacy deserialization path.
226        if buffers.len() == 2 {
227            if children.len() != 2 {
228                vortex_bail!(InvalidArgument: "Expected 2 children, got {}", children.len());
229            }
230            let codes = children.get(0, &DType::Binary(dtype.nullability()), len)?;
231            let codes = codes
232                .as_opt::<VarBinVTable>()
233                .ok_or_else(|| {
234                    vortex_err!(
235                        "Expected VarBinArray for codes, got {}",
236                        codes.encoding_id()
237                    )
238                })?
239                .clone();
240            let uncompressed_lengths = children.get(
241                1,
242                &DType::Primitive(
243                    metadata.0.get_uncompressed_lengths_ptype()?,
244                    Nullability::NonNullable,
245                ),
246                len,
247            )?;
248
249            return FSSTArray::try_new(
250                dtype.clone(),
251                symbols,
252                symbol_lengths,
253                codes,
254                uncompressed_lengths,
255            );
256        }
257
258        // Check for the current deserialization path.
259        if buffers.len() == 3 {
260            let uncompressed_lengths = children.get(
261                0,
262                &DType::Primitive(
263                    metadata.0.get_uncompressed_lengths_ptype()?,
264                    Nullability::NonNullable,
265                ),
266                len,
267            )?;
268
269            let codes_buffer = ByteBuffer::from_byte_buffer(buffers[2].clone().try_to_host_sync()?);
270            let codes_offsets = children.get(
271                1,
272                &DType::Primitive(
273                    PType::try_from(metadata.codes_offsets_ptype)?,
274                    Nullability::NonNullable,
275                ),
276                // VarBin offsets are len + 1
277                len + 1,
278            )?;
279
280            let codes_validity = if children.len() == 2 {
281                Validity::from(dtype.nullability())
282            } else if children.len() == 3 {
283                let validity = children.get(2, &Validity::DTYPE, len)?;
284                Validity::Array(validity)
285            } else {
286                vortex_bail!("Expected 0 or 1 child, got {}", children.len());
287            };
288
289            let codes = VarBinArray::try_new(
290                codes_offsets,
291                codes_buffer,
292                DType::Binary(dtype.nullability()),
293                codes_validity,
294            )?;
295
296            return FSSTArray::try_new(
297                dtype.clone(),
298                symbols,
299                symbol_lengths,
300                codes,
301                uncompressed_lengths,
302            );
303        }
304
305        vortex_bail!(
306            "InvalidArgument: Expected 2 or 3 buffers, got {}",
307            buffers.len()
308        );
309    }
310
311    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
312        vortex_ensure!(
313            children.len() == 2,
314            "FSSTArray expects 2 children, got {}",
315            children.len()
316        );
317
318        let mut children_iter = children.into_iter();
319        let codes = children_iter
320            .next()
321            .ok_or_else(|| vortex_err!("FSSTArray with_children missing codes"))?;
322
323        let codes = codes
324            .as_opt::<VarBinVTable>()
325            .ok_or_else(|| {
326                vortex_err!(
327                    "Expected VarBinArray for codes, got {}",
328                    codes.encoding_id()
329                )
330            })?
331            .clone();
332        let uncompressed_lengths = children_iter
333            .next()
334            .ok_or_else(|| vortex_err!("FSSTArray with_children missing uncompressed_lengths"))?;
335
336        array.codes = codes;
337        array.uncompressed_lengths = uncompressed_lengths;
338
339        Ok(())
340    }
341
342    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
343        canonicalize_fsst(array, ctx).map(ExecutionStep::Done)
344    }
345
346    fn execute_parent(
347        array: &Self::Array,
348        parent: &ArrayRef,
349        child_idx: usize,
350        ctx: &mut ExecutionCtx,
351    ) -> VortexResult<Option<ArrayRef>> {
352        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
353    }
354
355    fn reduce_parent(
356        array: &Self::Array,
357        parent: &ArrayRef,
358        child_idx: usize,
359    ) -> VortexResult<Option<ArrayRef>> {
360        RULES.evaluate(array, parent, child_idx)
361    }
362}
363
/// An FSST-compressed array. Each element is stored as a sequence of 8-bit codes referring to a
/// shared symbol table (see [`FSSTArray::try_new`]).
#[derive(Clone)]
pub struct FSSTArray {
    /// Logical dtype of the array, as reported by [`VTable::dtype`].
    dtype: DType,
    /// Symbol table used to encode `codes`.
    symbols: Buffer<Symbol>,
    /// Length in bytes of each entry of `symbols`; same number of entries (checked in `try_new`).
    symbol_lengths: Buffer<u8>,
    /// Compressed values: one binary datum of codes per element. Its validity is the array's
    /// validity (see the `ValidityChild` impl).
    codes: VarBinArray,
    /// NOTE(ngates): this === codes, but is stored as an ArrayRef so we can return &ArrayRef!
    codes_array: ArrayRef,
    /// Lengths of the original values before compression, can be compressed.
    uncompressed_lengths: ArrayRef,
    /// Statistics for this array, exposed via [`VTable::stats`].
    stats_set: ArrayStats,

    /// Memoized compressor used for push-down of compute by compressing the RHS.
    ///
    /// Built lazily from `symbols` and `symbol_lengths` on first access (see `new_unchecked`).
    compressor: Arc<LazyLock<Compressor, Box<dyn Fn() -> Compressor + Send>>>,
}
379
380impl Debug for FSSTArray {
381    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
382        f.debug_struct("FSSTArray")
383            .field("dtype", &self.dtype)
384            .field("symbols", &self.symbols)
385            .field("symbol_lengths", &self.symbol_lengths)
386            .field("codes", &self.codes)
387            .field("uncompressed_lengths", &self.uncompressed_lengths)
388            .finish()
389    }
390}
391
/// Marker type implementing [`VTable`] for the FSST encoding, whose array type is [`FSSTArray`].
#[derive(Debug)]
pub struct FSSTVTable;
394
impl FSSTVTable {
    /// Encoding identifier of FSST arrays, returned by [`VTable::id`].
    pub const ID: ArrayId = ArrayId::new_ref("vortex.fsst");
}
398
399impl FSSTArray {
400    /// Build an FSST array from a set of `symbols` and `codes`.
401    ///
402    /// Symbols are 8-bytes and can represent short strings, each of which is assigned
403    /// a code.
404    ///
405    /// The `codes` array is a Binary array where each binary datum is a sequence of 8-bit codes.
406    /// Each code corresponds either to a symbol, or to the "escape code",
407    /// which tells the decoder to emit the following byte without doing a table lookup.
408    pub fn try_new(
409        dtype: DType,
410        symbols: Buffer<Symbol>,
411        symbol_lengths: Buffer<u8>,
412        codes: VarBinArray,
413        uncompressed_lengths: ArrayRef,
414    ) -> VortexResult<Self> {
415        // Check: symbols must not have length > MAX_CODE
416        if symbols.len() > 255 {
417            vortex_bail!(InvalidArgument: "symbols array must have length <= 255");
418        }
419        if symbols.len() != symbol_lengths.len() {
420            vortex_bail!(InvalidArgument: "symbols and symbol_lengths arrays must have same length");
421        }
422
423        if uncompressed_lengths.len() != codes.len() {
424            vortex_bail!(InvalidArgument: "uncompressed_lengths must be same len as codes");
425        }
426
427        if !uncompressed_lengths.dtype().is_int() || uncompressed_lengths.dtype().is_nullable() {
428            vortex_bail!(InvalidArgument: "uncompressed_lengths must have integer type and cannot be nullable, found {}", uncompressed_lengths.dtype());
429        }
430
431        // Check: strings must be a Binary array.
432        if !matches!(codes.dtype(), DType::Binary(_)) {
433            vortex_bail!(InvalidArgument: "codes array must be DType::Binary type");
434        }
435
436        // SAFETY: all components validated above
437        unsafe {
438            Ok(Self::new_unchecked(
439                dtype,
440                symbols,
441                symbol_lengths,
442                codes,
443                uncompressed_lengths,
444            ))
445        }
446    }
447
448    pub(crate) unsafe fn new_unchecked(
449        dtype: DType,
450        symbols: Buffer<Symbol>,
451        symbol_lengths: Buffer<u8>,
452        codes: VarBinArray,
453        uncompressed_lengths: ArrayRef,
454    ) -> Self {
455        let symbols2 = symbols.clone();
456        let symbol_lengths2 = symbol_lengths.clone();
457        let compressor = Arc::new(LazyLock::new(Box::new(move || {
458            Compressor::rebuild_from(symbols2.as_slice(), symbol_lengths2.as_slice())
459        })
460            as Box<dyn Fn() -> Compressor + Send>));
461        let codes_array = codes.clone().into_array();
462
463        Self {
464            dtype,
465            symbols,
466            symbol_lengths,
467            codes,
468            codes_array,
469            uncompressed_lengths,
470            stats_set: Default::default(),
471            compressor,
472        }
473    }
474
475    /// Access the symbol table array
476    pub fn symbols(&self) -> &Buffer<Symbol> {
477        &self.symbols
478    }
479
480    /// Access the symbol table array
481    pub fn symbol_lengths(&self) -> &Buffer<u8> {
482        &self.symbol_lengths
483    }
484
485    /// Access the codes array
486    pub fn codes(&self) -> &VarBinArray {
487        &self.codes
488    }
489
490    /// Get the DType of the codes array
491    #[inline]
492    pub fn codes_dtype(&self) -> &DType {
493        self.codes.dtype()
494    }
495
496    /// Get the uncompressed length for each element in the array.
497    pub fn uncompressed_lengths(&self) -> &ArrayRef {
498        &self.uncompressed_lengths
499    }
500
501    /// Get the DType of the uncompressed lengths array
502    #[inline]
503    pub fn uncompressed_lengths_dtype(&self) -> &DType {
504        self.uncompressed_lengths.dtype()
505    }
506
507    /// Build a [`Decompressor`][fsst::Decompressor] that can be used to decompress values from
508    /// this array.
509    pub fn decompressor(&self) -> Decompressor<'_> {
510        Decompressor::new(self.symbols().as_slice(), self.symbol_lengths().as_slice())
511    }
512
513    /// Retrieves the FSST compressor.
514    pub fn compressor(&self) -> &Compressor {
515        self.compressor.as_ref()
516    }
517}
518
impl ValidityChild<FSSTVTable> for FSSTVTable {
    /// The array's validity is delegated to the codes array. The `codes_array` mirror is
    /// returned because this method must hand out a `&ArrayRef`.
    fn validity_child(array: &FSSTArray) -> &ArrayRef {
        &array.codes_array
    }
}
524
// Unit tests for the FSST array: metadata serialization and the legacy deserialization path.
#[cfg(test)]
mod test {
    use fsst::Compressor;
    use fsst::Symbol;
    use vortex_array::DynArray;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::ProstMetadata;
    use vortex_array::VortexSessionExecute;
    use vortex_array::accessor::ArrayAccessor;
    use vortex_array::arrays::VarBinViewArray;
    use vortex_array::buffer::BufferHandle;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::test_harness::check_metadata;
    use vortex_array::vtable::VTable;
    use vortex_buffer::Buffer;
    use vortex_error::VortexError;

    use crate::FSSTVTable;
    use crate::array::FSSTMetadata;
    use crate::fsst_compress_iter;

    /// Runs `check_metadata` on a sample `FSSTMetadata` under the name `fsst.metadata`.
    /// NOTE(review): presumably a golden-file comparison of the serialized bytes — confirm in
    /// `vortex_array::test_harness`. Skipped under Miri.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_fsst_metadata() {
        check_metadata(
            "fsst.metadata",
            ProstMetadata(FSSTMetadata {
                uncompressed_lengths_ptype: PType::U64 as i32,
                codes_offsets_ptype: PType::I32 as i32,
            }),
        );
    }

    /// The original FSST array stored codes as a VarBinArray child and required that the child
    /// have this encoding. Vortex forbids this kind of introspection, therefore we had to fix
    /// the array to store the compressed offsets and compressed data buffer separately, and only
    /// use VarBinArray to delegate behavior.
    ///
    /// This test manually constructs an old-style FSST array and ensures that it can still be
    /// deserialized.
    #[test]
    fn test_back_compat() {
        let symbols = Buffer::<Symbol>::copy_from([
            Symbol::from_slice(b"abc00000"),
            Symbol::from_slice(b"defghijk"),
        ]);
        // Only the first 3 bytes of the first symbol ("abc") are significant.
        let symbol_lengths = Buffer::<u8>::copy_from([3, 8]);

        let compressor = Compressor::rebuild_from(symbols.as_slice(), symbol_lengths.as_slice());
        let fsst_array = fsst_compress_iter(
            [Some(b"abcabcab".as_ref()), Some(b"defghijk".as_ref())].into_iter(),
            2,
            DType::Utf8(Nullability::NonNullable),
            &compressor,
        );

        let compressed_codes = fsst_array.codes().clone();

        // There were two buffers:
        // 1. The 8 byte symbols
        // 2. The symbol lengths as u8.
        let buffers = [
            BufferHandle::new_host(symbols.into_byte_buffer()),
            BufferHandle::new_host(symbol_lengths.into_byte_buffer()),
        ];

        // There were 2 children:
        // 1. The compressed codes, stored as a VarBinArray.
        // 2. The uncompressed lengths, stored as a Primitive array.
        let children = vec![
            compressed_codes.into_array(),
            fsst_array.uncompressed_lengths().clone(),
        ];

        let fsst = FSSTVTable::build(
            &DType::Utf8(Nullability::NonNullable),
            2,
            &ProstMetadata(FSSTMetadata {
                uncompressed_lengths_ptype: fsst_array
                    .uncompressed_lengths()
                    .dtype()
                    .as_ptype()
                    .into(),
                // Legacy array did not store this field, use Protobuf default of 0.
                codes_offsets_ptype: 0,
            }),
            &buffers,
            &children.as_slice(),
        )
        .unwrap();

        // Decompress and verify both original values round-trip.
        let decompressed = fsst
            .into_array()
            .execute::<VarBinViewArray>(&mut LEGACY_SESSION.create_execution_ctx())
            .unwrap();
        decompressed
            .with_iterator(|it| {
                assert_eq!(it.next().unwrap(), Some(b"abcabcab".as_ref()));
                assert_eq!(it.next().unwrap(), Some(b"defghijk".as_ref()));
                Ok::<_, VortexError>(())
            })
            .unwrap()
    }
}