1use std::fmt::{Debug, Formatter};
5use std::hash::Hash;
6use std::sync::{Arc, LazyLock};
7
8use fsst::{Compressor, Decompressor, Symbol};
9use vortex_array::arrays::VarBinArray;
10use vortex_array::stats::{ArrayStats, StatsSetRef};
11use vortex_array::vtable::{
12 ArrayVTable, NotSupported, VTable, ValidityChild, ValidityVTableFromChild,
13};
14use vortex_array::{
15 Array, ArrayEq, ArrayHash, ArrayRef, EncodingId, EncodingRef, Precision, vtable,
16};
17use vortex_buffer::Buffer;
18use vortex_dtype::DType;
19use vortex_error::{VortexResult, vortex_bail};
20
21vtable!(FSST);
22
23impl VTable for FSSTVTable {
24 type Array = FSSTArray;
25 type Encoding = FSSTEncoding;
26
27 type ArrayVTable = Self;
28 type CanonicalVTable = Self;
29 type OperationsVTable = Self;
30 type ValidityVTable = ValidityVTableFromChild;
31 type VisitorVTable = Self;
32 type ComputeVTable = NotSupported;
33 type EncodeVTable = Self;
34 type SerdeVTable = Self;
35 type OperatorVTable = NotSupported;
36
37 fn id(_encoding: &Self::Encoding) -> EncodingId {
38 EncodingId::new_ref("vortex.fsst")
39 }
40
41 fn encoding(_array: &Self::Array) -> EncodingRef {
42 EncodingRef::new_ref(FSSTEncoding.as_ref())
43 }
44}
45
/// An FSST-compressed array.
///
/// Each element is stored as a compressed byte string in `codes`. The compressor and the
/// decompressor are both built from the symbol table (`symbols` + `symbol_lengths`).
#[derive(Clone)]
pub struct FSSTArray {
    /// Logical dtype reported for the array by `ArrayVTable::dtype`.
    dtype: DType,
    /// FSST symbol table; `try_new` rejects tables with more than 255 entries.
    symbols: Buffer<Symbol>,
    /// Length of each entry in `symbols`; `try_new` checks both buffers have the same length.
    symbol_lengths: Buffer<u8>,
    /// Compressed bytes, one `DType::Binary` value per element. This child also determines the
    /// array's length (`ArrayVTable::len`) and validity (`ValidityChild`).
    codes: VarBinArray,
    /// Non-nullable integer array with one entry per element of `codes`. Presumably the decoded
    /// length of each element (inferred from the name) — confirm against the decoder.
    uncompressed_lengths: ArrayRef,
    /// Statistics storage exposed through `ArrayVTable::stats`.
    stats_set: ArrayStats,
    /// Compressor rebuilt lazily from the symbol table on first access. The `Arc` makes clones
    /// of the array share one instance. It is not included in the `Debug` output.
    compressor: Arc<LazyLock<Compressor, Box<dyn Fn() -> Compressor + Send>>>,
}
59
60impl Debug for FSSTArray {
61 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
62 f.debug_struct("FSSTArray")
63 .field("dtype", &self.dtype)
64 .field("symbols", &self.symbols)
65 .field("symbol_lengths", &self.symbol_lengths)
66 .field("codes", &self.codes)
67 .field("uncompressed_lengths", &self.uncompressed_lengths)
68 .finish()
69 }
70}
71
/// Unit marker type for the FSST encoding; used as `FSSTVTable`'s `Encoding` type.
#[derive(Clone, Debug)]
pub struct FSSTEncoding;
74
75impl FSSTArray {
76 pub fn try_new(
85 dtype: DType,
86 symbols: Buffer<Symbol>,
87 symbol_lengths: Buffer<u8>,
88 codes: VarBinArray,
89 uncompressed_lengths: ArrayRef,
90 ) -> VortexResult<Self> {
91 if symbols.len() > 255 {
93 vortex_bail!(InvalidArgument: "symbols array must have length <= 255");
94 }
95 if symbols.len() != symbol_lengths.len() {
96 vortex_bail!(InvalidArgument: "symbols and symbol_lengths arrays must have same length");
97 }
98
99 if uncompressed_lengths.len() != codes.len() {
100 vortex_bail!(InvalidArgument: "uncompressed_lengths must be same len as codes");
101 }
102
103 if !uncompressed_lengths.dtype().is_int() || uncompressed_lengths.dtype().is_nullable() {
104 vortex_bail!(InvalidArgument: "uncompressed_lengths must have integer type and cannot be nullable, found {}", uncompressed_lengths.dtype());
105 }
106
107 if !matches!(codes.dtype(), DType::Binary(_)) {
109 vortex_bail!(InvalidArgument: "codes array must be DType::Binary type");
110 }
111
112 unsafe {
114 Ok(Self::new_unchecked(
115 dtype,
116 symbols,
117 symbol_lengths,
118 codes,
119 uncompressed_lengths,
120 ))
121 }
122 }
123
124 pub(crate) unsafe fn new_unchecked(
125 dtype: DType,
126 symbols: Buffer<Symbol>,
127 symbol_lengths: Buffer<u8>,
128 codes: VarBinArray,
129 uncompressed_lengths: ArrayRef,
130 ) -> Self {
131 let symbols2 = symbols.clone();
132 let symbol_lengths2 = symbol_lengths.clone();
133 let compressor = Arc::new(LazyLock::new(Box::new(move || {
134 Compressor::rebuild_from(symbols2.as_slice(), symbol_lengths2.as_slice())
135 })
136 as Box<dyn Fn() -> Compressor + Send>));
137
138 Self {
139 dtype,
140 symbols,
141 symbol_lengths,
142 codes,
143 uncompressed_lengths,
144 stats_set: Default::default(),
145 compressor,
146 }
147 }
148
149 pub fn symbols(&self) -> &Buffer<Symbol> {
151 &self.symbols
152 }
153
154 pub fn symbol_lengths(&self) -> &Buffer<u8> {
156 &self.symbol_lengths
157 }
158
159 pub fn codes(&self) -> &VarBinArray {
161 &self.codes
162 }
163
164 #[inline]
166 pub fn codes_dtype(&self) -> &DType {
167 self.codes.dtype()
168 }
169
170 pub fn uncompressed_lengths(&self) -> &ArrayRef {
172 &self.uncompressed_lengths
173 }
174
175 #[inline]
177 pub fn uncompressed_lengths_dtype(&self) -> &DType {
178 self.uncompressed_lengths.dtype()
179 }
180
181 pub(crate) fn decompressor(&self) -> Decompressor<'_> {
186 Decompressor::new(self.symbols().as_slice(), self.symbol_lengths().as_slice())
187 }
188
189 pub(crate) fn compressor(&self) -> &Compressor {
190 self.compressor.as_ref()
191 }
192}
193
194impl ArrayVTable<FSSTVTable> for FSSTVTable {
195 fn len(array: &FSSTArray) -> usize {
196 array.codes().len()
197 }
198
199 fn dtype(array: &FSSTArray) -> &DType {
200 &array.dtype
201 }
202
203 fn stats(array: &FSSTArray) -> StatsSetRef<'_> {
204 array.stats_set.to_ref(array.as_ref())
205 }
206
207 fn array_hash<H: std::hash::Hasher>(array: &FSSTArray, state: &mut H, precision: Precision) {
208 array.dtype.hash(state);
209 array.symbols.array_hash(state, precision);
210 array.symbol_lengths.array_hash(state, precision);
211 array.codes.as_ref().array_hash(state, precision);
212 array.uncompressed_lengths.array_hash(state, precision);
213 }
214
215 fn array_eq(array: &FSSTArray, other: &FSSTArray, precision: Precision) -> bool {
216 array.dtype == other.dtype
217 && array.symbols.array_eq(&other.symbols, precision)
218 && array
219 .symbol_lengths
220 .array_eq(&other.symbol_lengths, precision)
221 && array
222 .codes
223 .as_ref()
224 .array_eq(other.codes.as_ref(), precision)
225 && array
226 .uncompressed_lengths
227 .array_eq(&other.uncompressed_lengths, precision)
228 }
229}
230
231impl ValidityChild<FSSTVTable> for FSSTVTable {
232 fn validity_child(array: &FSSTArray) -> &dyn Array {
233 array.codes().as_ref()
234 }
235}