Skip to main content

vortex_fsst/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8use std::sync::LazyLock;
9
10use fsst::Compressor;
11use fsst::Decompressor;
12use fsst::Symbol;
13use vortex_array::ArrayEq;
14use vortex_array::ArrayHash;
15use vortex_array::ArrayRef;
16use vortex_array::Canonical;
17use vortex_array::DeserializeMetadata;
18use vortex_array::DynArray;
19use vortex_array::ExecutionCtx;
20use vortex_array::ExecutionResult;
21use vortex_array::IntoArray;
22use vortex_array::Precision;
23use vortex_array::ProstMetadata;
24use vortex_array::SerializeMetadata;
25use vortex_array::arrays::VarBin;
26use vortex_array::arrays::VarBinArray;
27use vortex_array::buffer::BufferHandle;
28use vortex_array::builders::ArrayBuilder;
29use vortex_array::builders::VarBinViewBuilder;
30use vortex_array::dtype::DType;
31use vortex_array::dtype::Nullability;
32use vortex_array::dtype::PType;
33use vortex_array::serde::ArrayChildren;
34use vortex_array::stats::ArrayStats;
35use vortex_array::stats::StatsSetRef;
36use vortex_array::validity::Validity;
37use vortex_array::vtable;
38use vortex_array::vtable::Array;
39use vortex_array::vtable::ArrayId;
40use vortex_array::vtable::VTable;
41use vortex_array::vtable::ValidityChild;
42use vortex_array::vtable::ValidityHelper;
43use vortex_array::vtable::ValidityVTableFromChild;
44use vortex_array::vtable::validity_nchildren;
45use vortex_array::vtable::validity_to_child;
46use vortex_buffer::Buffer;
47use vortex_buffer::ByteBuffer;
48use vortex_error::VortexResult;
49use vortex_error::vortex_bail;
50use vortex_error::vortex_ensure;
51use vortex_error::vortex_err;
52use vortex_error::vortex_panic;
53use vortex_session::VortexSession;
54
55use crate::canonical::canonicalize_fsst;
56use crate::canonical::fsst_decode_views;
57use crate::kernel::PARENT_KERNELS;
58use crate::rules::RULES;
59
// Expands `vortex_array`'s `vtable!` macro for the `FSST` encoding; the generated items are
// defined by that macro, not in this file.
vtable!(FSST);
61
/// Serialized (Protobuf) metadata for an FSST array.
#[derive(Clone, prost::Message)]
pub struct FSSTMetadata {
    /// Raw `PType` discriminant of the `uncompressed_lengths` child.
    #[prost(enumeration = "PType", tag = "1")]
    uncompressed_lengths_ptype: i32,

    /// Raw `PType` discriminant of the codes offsets child. Legacy (two-buffer) arrays did not
    /// write this field, so it decodes as the Protobuf default of 0 for them.
    #[prost(enumeration = "PType", tag = "2")]
    codes_offsets_ptype: i32,
}
70
71impl FSSTMetadata {
72    pub fn get_uncompressed_lengths_ptype(&self) -> VortexResult<PType> {
73        PType::try_from(self.uncompressed_lengths_ptype)
74            .map_err(|_| vortex_err!("Invalid PType {}", self.uncompressed_lengths_ptype))
75    }
76}
77
78impl VTable for FSST {
79    type Array = FSSTArray;
80
81    type Metadata = ProstMetadata<FSSTMetadata>;
82    type OperationsVTable = Self;
83    type ValidityVTable = ValidityVTableFromChild;
84
85    fn vtable(_array: &Self::Array) -> &Self {
86        &FSST
87    }
88
89    fn id(&self) -> ArrayId {
90        Self::ID
91    }
92
93    fn len(array: &FSSTArray) -> usize {
94        array.codes().len()
95    }
96
97    fn dtype(array: &FSSTArray) -> &DType {
98        &array.dtype
99    }
100
101    fn stats(array: &FSSTArray) -> StatsSetRef<'_> {
102        array.stats_set.to_ref(array.as_ref())
103    }
104
105    fn array_hash<H: std::hash::Hasher>(array: &FSSTArray, state: &mut H, precision: Precision) {
106        array.dtype.hash(state);
107        array.symbols.array_hash(state, precision);
108        array.symbol_lengths.array_hash(state, precision);
109        array.codes.as_ref().array_hash(state, precision);
110        array.uncompressed_lengths.array_hash(state, precision);
111    }
112
113    fn array_eq(array: &FSSTArray, other: &FSSTArray, precision: Precision) -> bool {
114        array.dtype == other.dtype
115            && array.symbols.array_eq(&other.symbols, precision)
116            && array
117                .symbol_lengths
118                .array_eq(&other.symbol_lengths, precision)
119            && array
120                .codes
121                .as_ref()
122                .array_eq(other.codes.as_ref(), precision)
123            && array
124                .uncompressed_lengths
125                .array_eq(&other.uncompressed_lengths, precision)
126    }
127
128    fn nbuffers(_array: &FSSTArray) -> usize {
129        3
130    }
131
132    fn buffer(array: &FSSTArray, idx: usize) -> BufferHandle {
133        match idx {
134            0 => BufferHandle::new_host(array.symbols().clone().into_byte_buffer()),
135            1 => BufferHandle::new_host(array.symbol_lengths().clone().into_byte_buffer()),
136            2 => array.codes.bytes_handle().clone(),
137            _ => vortex_panic!("FSSTArray buffer index {idx} out of bounds"),
138        }
139    }
140
141    fn buffer_name(_array: &FSSTArray, idx: usize) -> Option<String> {
142        match idx {
143            0 => Some("symbols".to_string()),
144            1 => Some("symbol_lengths".to_string()),
145            2 => Some("compressed_codes".to_string()),
146            _ => vortex_panic!("FSSTArray buffer_name index {idx} out of bounds"),
147        }
148    }
149
150    fn nchildren(array: &FSSTArray) -> usize {
151        2 + validity_nchildren(array.codes.validity())
152    }
153
154    fn child(array: &FSSTArray, idx: usize) -> ArrayRef {
155        match idx {
156            0 => array.uncompressed_lengths().clone(),
157            1 => array.codes.offsets().clone(),
158            2 => validity_to_child(array.codes.validity(), array.codes.len())
159                .unwrap_or_else(|| vortex_panic!("FSSTArray child index {idx} out of bounds")),
160            _ => vortex_panic!("FSSTArray child index {idx} out of bounds"),
161        }
162    }
163
164    fn child_name(_array: &FSSTArray, idx: usize) -> String {
165        match idx {
166            0 => "uncompressed_lengths".to_string(),
167            1 => "codes_offsets".to_string(),
168            2 => "validity".to_string(),
169            _ => vortex_panic!("FSSTArray child_name index {idx} out of bounds"),
170        }
171    }
172
173    fn metadata(array: &FSSTArray) -> VortexResult<Self::Metadata> {
174        Ok(ProstMetadata(FSSTMetadata {
175            uncompressed_lengths_ptype: array.uncompressed_lengths().dtype().as_ptype().into(),
176            codes_offsets_ptype: array.codes.offsets().dtype().as_ptype().into(),
177        }))
178    }
179
180    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
181        Ok(Some(metadata.serialize()))
182    }
183
184    fn deserialize(
185        bytes: &[u8],
186        _dtype: &DType,
187        _len: usize,
188        _buffers: &[BufferHandle],
189        _session: &VortexSession,
190    ) -> VortexResult<Self::Metadata> {
191        Ok(ProstMetadata(
192            <ProstMetadata<FSSTMetadata> as DeserializeMetadata>::deserialize(bytes)?,
193        ))
194    }
195
196    fn append_to_builder(
197        array: &FSSTArray,
198        builder: &mut dyn ArrayBuilder,
199        ctx: &mut ExecutionCtx,
200    ) -> VortexResult<()> {
201        let Some(builder) = builder.as_any_mut().downcast_mut::<VarBinViewBuilder>() else {
202            builder.extend_from_array(
203                &array
204                    .clone()
205                    .into_array()
206                    .execute::<Canonical>(ctx)?
207                    .into_array(),
208            );
209            return Ok(());
210        };
211
212        // Decompress the whole block of data into a new buffer, and create some views
213        // from it instead.
214        let (buffers, views) = fsst_decode_views(array, builder.completed_block_count(), ctx)?;
215
216        builder.push_buffer_and_adjusted_views(&buffers, &views, array.validity_mask()?);
217        Ok(())
218    }
219
220    fn build(
221        dtype: &DType,
222        len: usize,
223        metadata: &Self::Metadata,
224        buffers: &[BufferHandle],
225        children: &dyn ArrayChildren,
226    ) -> VortexResult<FSSTArray> {
227        let symbols = Buffer::<Symbol>::from_byte_buffer(buffers[0].clone().try_to_host_sync()?);
228        let symbol_lengths = Buffer::<u8>::from_byte_buffer(buffers[1].clone().try_to_host_sync()?);
229
230        // Check for the legacy deserialization path.
231        if buffers.len() == 2 {
232            if children.len() != 2 {
233                vortex_bail!(InvalidArgument: "Expected 2 children, got {}", children.len());
234            }
235            let codes = children.get(0, &DType::Binary(dtype.nullability()), len)?;
236            let codes = codes
237                .as_opt::<VarBin>()
238                .ok_or_else(|| {
239                    vortex_err!(
240                        "Expected VarBinArray for codes, got {}",
241                        codes.encoding_id()
242                    )
243                })?
244                .clone();
245            let uncompressed_lengths = children.get(
246                1,
247                &DType::Primitive(
248                    metadata.0.get_uncompressed_lengths_ptype()?,
249                    Nullability::NonNullable,
250                ),
251                len,
252            )?;
253
254            return FSSTArray::try_new(
255                dtype.clone(),
256                symbols,
257                symbol_lengths,
258                codes,
259                uncompressed_lengths,
260            );
261        }
262
263        // Check for the current deserialization path.
264        if buffers.len() == 3 {
265            let uncompressed_lengths = children.get(
266                0,
267                &DType::Primitive(
268                    metadata.0.get_uncompressed_lengths_ptype()?,
269                    Nullability::NonNullable,
270                ),
271                len,
272            )?;
273
274            let codes_buffer = ByteBuffer::from_byte_buffer(buffers[2].clone().try_to_host_sync()?);
275            let codes_offsets = children.get(
276                1,
277                &DType::Primitive(
278                    PType::try_from(metadata.codes_offsets_ptype)?,
279                    Nullability::NonNullable,
280                ),
281                // VarBin offsets are len + 1
282                len + 1,
283            )?;
284
285            let codes_validity = if children.len() == 2 {
286                Validity::from(dtype.nullability())
287            } else if children.len() == 3 {
288                let validity = children.get(2, &Validity::DTYPE, len)?;
289                Validity::Array(validity)
290            } else {
291                vortex_bail!("Expected 0 or 1 child, got {}", children.len());
292            };
293
294            let codes = VarBinArray::try_new(
295                codes_offsets,
296                codes_buffer,
297                DType::Binary(dtype.nullability()),
298                codes_validity,
299            )?;
300
301            return FSSTArray::try_new(
302                dtype.clone(),
303                symbols,
304                symbol_lengths,
305                codes,
306                uncompressed_lengths,
307            );
308        }
309
310        vortex_bail!(
311            "InvalidArgument: Expected 2 or 3 buffers, got {}",
312            buffers.len()
313        );
314    }
315
316    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
317        vortex_ensure!(
318            children.len() == 2,
319            "FSSTArray expects 2 children, got {}",
320            children.len()
321        );
322
323        let mut children_iter = children.into_iter();
324        let codes = children_iter
325            .next()
326            .ok_or_else(|| vortex_err!("FSSTArray with_children missing codes"))?;
327
328        let codes = codes
329            .as_opt::<VarBin>()
330            .ok_or_else(|| {
331                vortex_err!(
332                    "Expected VarBinArray for codes, got {}",
333                    codes.encoding_id()
334                )
335            })?
336            .clone();
337        let uncompressed_lengths = children_iter
338            .next()
339            .ok_or_else(|| vortex_err!("FSSTArray with_children missing uncompressed_lengths"))?;
340
341        array.codes = codes;
342        array.uncompressed_lengths = uncompressed_lengths;
343
344        Ok(())
345    }
346
347    fn execute(array: Arc<Array<Self>>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
348        canonicalize_fsst(&array, ctx).map(ExecutionResult::done)
349    }
350
351    fn execute_parent(
352        array: &Array<Self>,
353        parent: &ArrayRef,
354        child_idx: usize,
355        ctx: &mut ExecutionCtx,
356    ) -> VortexResult<Option<ArrayRef>> {
357        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
358    }
359
360    fn reduce_parent(
361        array: &Array<Self>,
362        parent: &ArrayRef,
363        child_idx: usize,
364    ) -> VortexResult<Option<ArrayRef>> {
365        RULES.evaluate(array, parent, child_idx)
366    }
367}
368
/// An array whose values are stored as FSST codes plus the symbol table needed to decode them.
#[derive(Clone)]
pub struct FSSTArray {
    /// Logical dtype of the decoded values.
    dtype: DType,
    /// Symbol table entries (8-byte `Symbol`s).
    symbols: Buffer<Symbol>,
    /// Number of significant bytes in each entry of `symbols`.
    symbol_lengths: Buffer<u8>,
    /// Compressed values: one binary code string per logical value.
    codes: VarBinArray,
    /// NOTE(ngates): this === codes, but is stored as an ArrayRef so we can return &ArrayRef!
    codes_array: ArrayRef,
    /// Lengths of the original values before compression, can be compressed.
    uncompressed_lengths: ArrayRef,
    /// Cached statistics for this array.
    stats_set: ArrayStats,

    /// Memoized compressor used for push-down of compute by compressing the RHS.
    compressor: Arc<LazyLock<Compressor, Box<dyn Fn() -> Compressor + Send>>>,
}
384
385impl Debug for FSSTArray {
386    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
387        f.debug_struct("FSSTArray")
388            .field("dtype", &self.dtype)
389            .field("symbols", &self.symbols)
390            .field("symbol_lengths", &self.symbol_lengths)
391            .field("codes", &self.codes)
392            .field("uncompressed_lengths", &self.uncompressed_lengths)
393            .finish()
394    }
395}
396
/// Encoding marker type for FSST; implements [`VTable`] with [`FSSTArray`] as its array type.
#[derive(Clone, Debug)]
pub struct FSST;

impl FSST {
    /// Encoding identifier (`"vortex.fsst"`), returned by `VTable::id`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.fsst");
}
403
404impl FSSTArray {
405    /// Build an FSST array from a set of `symbols` and `codes`.
406    ///
407    /// Symbols are 8-bytes and can represent short strings, each of which is assigned
408    /// a code.
409    ///
410    /// The `codes` array is a Binary array where each binary datum is a sequence of 8-bit codes.
411    /// Each code corresponds either to a symbol, or to the "escape code",
412    /// which tells the decoder to emit the following byte without doing a table lookup.
413    pub fn try_new(
414        dtype: DType,
415        symbols: Buffer<Symbol>,
416        symbol_lengths: Buffer<u8>,
417        codes: VarBinArray,
418        uncompressed_lengths: ArrayRef,
419    ) -> VortexResult<Self> {
420        // Check: symbols must not have length > MAX_CODE
421        if symbols.len() > 255 {
422            vortex_bail!(InvalidArgument: "symbols array must have length <= 255");
423        }
424        if symbols.len() != symbol_lengths.len() {
425            vortex_bail!(InvalidArgument: "symbols and symbol_lengths arrays must have same length");
426        }
427
428        if uncompressed_lengths.len() != codes.len() {
429            vortex_bail!(InvalidArgument: "uncompressed_lengths must be same len as codes");
430        }
431
432        if !uncompressed_lengths.dtype().is_int() || uncompressed_lengths.dtype().is_nullable() {
433            vortex_bail!(InvalidArgument: "uncompressed_lengths must have integer type and cannot be nullable, found {}", uncompressed_lengths.dtype());
434        }
435
436        // Check: strings must be a Binary array.
437        if !matches!(codes.dtype(), DType::Binary(_)) {
438            vortex_bail!(InvalidArgument: "codes array must be DType::Binary type");
439        }
440
441        // SAFETY: all components validated above
442        unsafe {
443            Ok(Self::new_unchecked(
444                dtype,
445                symbols,
446                symbol_lengths,
447                codes,
448                uncompressed_lengths,
449            ))
450        }
451    }
452
453    pub(crate) unsafe fn new_unchecked(
454        dtype: DType,
455        symbols: Buffer<Symbol>,
456        symbol_lengths: Buffer<u8>,
457        codes: VarBinArray,
458        uncompressed_lengths: ArrayRef,
459    ) -> Self {
460        let symbols2 = symbols.clone();
461        let symbol_lengths2 = symbol_lengths.clone();
462        let compressor = Arc::new(LazyLock::new(Box::new(move || {
463            Compressor::rebuild_from(symbols2.as_slice(), symbol_lengths2.as_slice())
464        })
465            as Box<dyn Fn() -> Compressor + Send>));
466        let codes_array = codes.clone().into_array();
467
468        Self {
469            dtype,
470            symbols,
471            symbol_lengths,
472            codes,
473            codes_array,
474            uncompressed_lengths,
475            stats_set: Default::default(),
476            compressor,
477        }
478    }
479
480    /// Access the symbol table array
481    pub fn symbols(&self) -> &Buffer<Symbol> {
482        &self.symbols
483    }
484
485    /// Access the symbol table array
486    pub fn symbol_lengths(&self) -> &Buffer<u8> {
487        &self.symbol_lengths
488    }
489
490    /// Access the codes array
491    pub fn codes(&self) -> &VarBinArray {
492        &self.codes
493    }
494
495    /// Get the DType of the codes array
496    #[inline]
497    pub fn codes_dtype(&self) -> &DType {
498        self.codes.dtype()
499    }
500
501    /// Get the uncompressed length for each element in the array.
502    pub fn uncompressed_lengths(&self) -> &ArrayRef {
503        &self.uncompressed_lengths
504    }
505
506    /// Get the DType of the uncompressed lengths array
507    #[inline]
508    pub fn uncompressed_lengths_dtype(&self) -> &DType {
509        self.uncompressed_lengths.dtype()
510    }
511
512    /// Build a [`Decompressor`][fsst::Decompressor] that can be used to decompress values from
513    /// this array.
514    pub fn decompressor(&self) -> Decompressor<'_> {
515        Decompressor::new(self.symbols().as_slice(), self.symbol_lengths().as_slice())
516    }
517
518    /// Retrieves the FSST compressor.
519    pub fn compressor(&self) -> &Compressor {
520        self.compressor.as_ref()
521    }
522}
523
/// FSST validity is delegated to the codes array: the validity of each value is read from the
/// corresponding code string.
impl ValidityChild<FSST> for FSST {
    fn validity_child(array: &FSSTArray) -> &ArrayRef {
        // `codes_array` is the `ArrayRef` form of `codes`, so a reference can be returned.
        &array.codes_array
    }
}
529
#[cfg(test)]
mod test {
    use fsst::Compressor;
    use fsst::Symbol;
    use vortex_array::DynArray;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::ProstMetadata;
    use vortex_array::VortexSessionExecute;
    use vortex_array::accessor::ArrayAccessor;
    use vortex_array::arrays::VarBinViewArray;
    use vortex_array::buffer::BufferHandle;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::test_harness::check_metadata;
    use vortex_array::vtable::VTable;
    use vortex_buffer::Buffer;
    use vortex_error::VortexError;

    use crate::FSST;
    use crate::array::FSSTMetadata;
    use crate::fsst_compress_iter;

    /// Checks the serialized form of [`FSSTMetadata`] via `check_metadata` under the name
    /// `fsst.metadata`.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_fsst_metadata() {
        let metadata = FSSTMetadata {
            uncompressed_lengths_ptype: PType::U64 as i32,
            codes_offsets_ptype: PType::I32 as i32,
        };
        check_metadata("fsst.metadata", ProstMetadata(metadata));
    }

    /// The original FSST array stored codes as a VarBinArray child and required that the child
    /// have this encoding. Vortex forbids this kind of introspection, therefore we had to fix
    /// the array to store the compressed offsets and compressed data buffer separately, and only
    /// use VarBinArray to delegate behavior.
    ///
    /// This test manually constructs an old-style FSST array and ensures that it can still be
    /// deserialized.
    #[test]
    fn test_back_compat() {
        let values: [&[u8]; 2] = [b"abcabcab", b"defghijk"];
        let dtype = DType::Utf8(Nullability::NonNullable);

        let symbols = Buffer::<Symbol>::copy_from([
            Symbol::from_slice(b"abc00000"),
            Symbol::from_slice(b"defghijk"),
        ]);
        let symbol_lengths = Buffer::<u8>::copy_from([3, 8]);

        let compressor = Compressor::rebuild_from(symbols.as_slice(), symbol_lengths.as_slice());
        let fsst_array = fsst_compress_iter(
            values.map(Some).into_iter(),
            values.len(),
            dtype.clone(),
            &compressor,
        );

        // The legacy layout had two buffers: the 8-byte symbols and the u8 symbol lengths.
        let buffers = [
            BufferHandle::new_host(symbols.into_byte_buffer()),
            BufferHandle::new_host(symbol_lengths.into_byte_buffer()),
        ];

        // ...and two children: the codes as a VarBinArray, then the uncompressed lengths as a
        // Primitive array.
        let uncompressed_lengths = fsst_array.uncompressed_lengths().clone();
        let metadata = ProstMetadata(FSSTMetadata {
            uncompressed_lengths_ptype: uncompressed_lengths.dtype().as_ptype().into(),
            // Legacy array did not store this field, use Protobuf default of 0.
            codes_offsets_ptype: 0,
        });
        let children = vec![fsst_array.codes().clone().into_array(), uncompressed_lengths];

        let rebuilt = FSST::build(
            &dtype,
            values.len(),
            &metadata,
            &buffers,
            &children.as_slice(),
        )
        .unwrap();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let decompressed = rebuilt
            .into_array()
            .execute::<VarBinViewArray>(&mut ctx)
            .unwrap();

        decompressed
            .with_iterator(|iter| {
                for expected in values {
                    assert_eq!(iter.next().unwrap(), Some(expected));
                }
                Ok::<_, VortexError>(())
            })
            .unwrap();
    }
}