vortex_fsst/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Debug, Formatter};
5use std::hash::Hash;
6use std::sync::{Arc, LazyLock};
7
8use fsst::{Compressor, Decompressor, Symbol};
9use vortex_array::arrays::VarBinArray;
10use vortex_array::stats::{ArrayStats, StatsSetRef};
11use vortex_array::vtable::{
12    ArrayVTable, NotSupported, VTable, ValidityChild, ValidityVTableFromChild,
13};
14use vortex_array::{
15    Array, ArrayEq, ArrayHash, ArrayRef, EncodingId, EncodingRef, Precision, vtable,
16};
17use vortex_buffer::Buffer;
18use vortex_dtype::DType;
19use vortex_error::{VortexResult, vortex_bail};
20
21vtable!(FSST);
22
23impl VTable for FSSTVTable {
24    type Array = FSSTArray;
25    type Encoding = FSSTEncoding;
26
27    type ArrayVTable = Self;
28    type CanonicalVTable = Self;
29    type OperationsVTable = Self;
30    type ValidityVTable = ValidityVTableFromChild;
31    type VisitorVTable = Self;
32    type ComputeVTable = NotSupported;
33    type EncodeVTable = Self;
34    type SerdeVTable = Self;
35    type OperatorVTable = NotSupported;
36
37    fn id(_encoding: &Self::Encoding) -> EncodingId {
38        EncodingId::new_ref("vortex.fsst")
39    }
40
41    fn encoding(_array: &Self::Array) -> EncodingRef {
42        EncodingRef::new_ref(FSSTEncoding.as_ref())
43    }
44}
45
46#[derive(Clone)]
47pub struct FSSTArray {
48    dtype: DType,
49    symbols: Buffer<Symbol>,
50    symbol_lengths: Buffer<u8>,
51    codes: VarBinArray,
52    /// Lengths of the original values before compression, can be compressed.
53    uncompressed_lengths: ArrayRef,
54    stats_set: ArrayStats,
55
56    /// Memoized compressor used for push-down of compute by compressing the RHS.
57    compressor: Arc<LazyLock<Compressor, Box<dyn Fn() -> Compressor + Send>>>,
58}
59
60impl Debug for FSSTArray {
61    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
62        f.debug_struct("FSSTArray")
63            .field("dtype", &self.dtype)
64            .field("symbols", &self.symbols)
65            .field("symbol_lengths", &self.symbol_lengths)
66            .field("codes", &self.codes)
67            .field("uncompressed_lengths", &self.uncompressed_lengths)
68            .finish()
69    }
70}
71
/// Encoding marker for FSST-compressed arrays, identified as `vortex.fsst`.
#[derive(Debug, Clone)]
pub struct FSSTEncoding;
74
75impl FSSTArray {
76    /// Build an FSST array from a set of `symbols` and `codes`.
77    ///
78    /// Symbols are 8-bytes and can represent short strings, each of which is assigned
79    /// a code.
80    ///
81    /// The `codes` array is a Binary array where each binary datum is a sequence of 8-bit codes.
82    /// Each code corresponds either to a symbol, or to the "escape code",
83    /// which tells the decoder to emit the following byte without doing a table lookup.
84    pub fn try_new(
85        dtype: DType,
86        symbols: Buffer<Symbol>,
87        symbol_lengths: Buffer<u8>,
88        codes: VarBinArray,
89        uncompressed_lengths: ArrayRef,
90    ) -> VortexResult<Self> {
91        // Check: symbols must not have length > MAX_CODE
92        if symbols.len() > 255 {
93            vortex_bail!(InvalidArgument: "symbols array must have length <= 255");
94        }
95        if symbols.len() != symbol_lengths.len() {
96            vortex_bail!(InvalidArgument: "symbols and symbol_lengths arrays must have same length");
97        }
98
99        if uncompressed_lengths.len() != codes.len() {
100            vortex_bail!(InvalidArgument: "uncompressed_lengths must be same len as codes");
101        }
102
103        if !uncompressed_lengths.dtype().is_int() || uncompressed_lengths.dtype().is_nullable() {
104            vortex_bail!(InvalidArgument: "uncompressed_lengths must have integer type and cannot be nullable, found {}", uncompressed_lengths.dtype());
105        }
106
107        // Check: strings must be a Binary array.
108        if !matches!(codes.dtype(), DType::Binary(_)) {
109            vortex_bail!(InvalidArgument: "codes array must be DType::Binary type");
110        }
111
112        // SAFETY: all components validated above
113        unsafe {
114            Ok(Self::new_unchecked(
115                dtype,
116                symbols,
117                symbol_lengths,
118                codes,
119                uncompressed_lengths,
120            ))
121        }
122    }
123
124    pub(crate) unsafe fn new_unchecked(
125        dtype: DType,
126        symbols: Buffer<Symbol>,
127        symbol_lengths: Buffer<u8>,
128        codes: VarBinArray,
129        uncompressed_lengths: ArrayRef,
130    ) -> Self {
131        let symbols2 = symbols.clone();
132        let symbol_lengths2 = symbol_lengths.clone();
133        let compressor = Arc::new(LazyLock::new(Box::new(move || {
134            Compressor::rebuild_from(symbols2.as_slice(), symbol_lengths2.as_slice())
135        })
136            as Box<dyn Fn() -> Compressor + Send>));
137
138        Self {
139            dtype,
140            symbols,
141            symbol_lengths,
142            codes,
143            uncompressed_lengths,
144            stats_set: Default::default(),
145            compressor,
146        }
147    }
148
149    /// Access the symbol table array
150    pub fn symbols(&self) -> &Buffer<Symbol> {
151        &self.symbols
152    }
153
154    /// Access the symbol table array
155    pub fn symbol_lengths(&self) -> &Buffer<u8> {
156        &self.symbol_lengths
157    }
158
159    /// Access the codes array
160    pub fn codes(&self) -> &VarBinArray {
161        &self.codes
162    }
163
164    /// Get the DType of the codes array
165    #[inline]
166    pub fn codes_dtype(&self) -> &DType {
167        self.codes.dtype()
168    }
169
170    /// Get the uncompressed length for each element in the array.
171    pub fn uncompressed_lengths(&self) -> &ArrayRef {
172        &self.uncompressed_lengths
173    }
174
175    /// Get the DType of the uncompressed lengths array
176    #[inline]
177    pub fn uncompressed_lengths_dtype(&self) -> &DType {
178        self.uncompressed_lengths.dtype()
179    }
180
181    /// Build a [`Decompressor`][fsst::Decompressor] that can be used to decompress values from
182    /// this array.
183    ///
184    /// This is private to the crate to avoid leaking `fsst-rs` types as part of the public API.
185    pub(crate) fn decompressor(&self) -> Decompressor<'_> {
186        Decompressor::new(self.symbols().as_slice(), self.symbol_lengths().as_slice())
187    }
188
189    pub(crate) fn compressor(&self) -> &Compressor {
190        self.compressor.as_ref()
191    }
192}
193
194impl ArrayVTable<FSSTVTable> for FSSTVTable {
195    fn len(array: &FSSTArray) -> usize {
196        array.codes().len()
197    }
198
199    fn dtype(array: &FSSTArray) -> &DType {
200        &array.dtype
201    }
202
203    fn stats(array: &FSSTArray) -> StatsSetRef<'_> {
204        array.stats_set.to_ref(array.as_ref())
205    }
206
207    fn array_hash<H: std::hash::Hasher>(array: &FSSTArray, state: &mut H, precision: Precision) {
208        array.dtype.hash(state);
209        array.symbols.array_hash(state, precision);
210        array.symbol_lengths.array_hash(state, precision);
211        array.codes.as_ref().array_hash(state, precision);
212        array.uncompressed_lengths.array_hash(state, precision);
213    }
214
215    fn array_eq(array: &FSSTArray, other: &FSSTArray, precision: Precision) -> bool {
216        array.dtype == other.dtype
217            && array.symbols.array_eq(&other.symbols, precision)
218            && array
219                .symbol_lengths
220                .array_eq(&other.symbol_lengths, precision)
221            && array
222                .codes
223                .as_ref()
224                .array_eq(other.codes.as_ref(), precision)
225            && array
226                .uncompressed_lengths
227                .array_eq(&other.uncompressed_lengths, precision)
228    }
229}
230
231impl ValidityChild<FSSTVTable> for FSSTVTable {
232    fn validity_child(array: &FSSTArray) -> &dyn Array {
233        array.codes().as_ref()
234    }
235}