vortex_btrblocks/
string.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::{
5    ConstantArray, MaskedArray, VarBinArray, VarBinViewArray, VarBinViewVTable,
6};
7use vortex_array::vtable::ValidityHelper;
8use vortex_array::{ArrayRef, IntoArray, ToCanonical};
9use vortex_dict::DictArray;
10use vortex_dict::builders::dict_encode;
11use vortex_error::{VortexExpect, VortexResult};
12use vortex_fsst::{FSSTArray, fsst_compress, fsst_train_compressor};
13use vortex_scalar::Scalar;
14use vortex_utils::aliases::hash_set::HashSet;
15
16use crate::integer::IntCompressor;
17use crate::sample::sample;
18use crate::{
19    Compressor, CompressorStats, GenerateStatsOptions, Scheme,
20    estimate_compression_ratio_with_sampling, integer,
21};
22
23/// Array of variable-length byte arrays, and relevant stats for compression.
24#[derive(Clone, Debug)]
25pub struct StringStats {
26    src: VarBinViewArray,
27    estimated_distinct_count: u32,
28    value_count: u32,
29}
30
31/// Estimate the number of distinct strings in the var bin view array.
32#[allow(clippy::cast_possible_truncation)]
33fn estimate_distinct_count(strings: &VarBinViewArray) -> u32 {
34    let views = strings.views();
35    // Iterate the views. Two strings which are equal must have the same first 8-bytes.
36    // NOTE: there are cases where this performs pessimally, e.g. when we have strings that all
37    // share a 4-byte prefix and have the same length.
38    let mut distinct = HashSet::with_capacity(views.len() / 2);
39    views.iter().for_each(|&view| {
40        let len_and_prefix = view.as_u128() as u64;
41        distinct.insert(len_and_prefix);
42    });
43
44    distinct
45        .len()
46        .try_into()
47        .vortex_expect("distinct count must fit in u32")
48}
49
50impl CompressorStats for StringStats {
51    type ArrayVTable = VarBinViewVTable;
52
53    fn generate_opts(input: &VarBinViewArray, opts: GenerateStatsOptions) -> Self {
54        let null_count = input
55            .statistics()
56            .compute_null_count()
57            .vortex_expect("null count");
58        let value_count = input.len() - null_count;
59        let estimated_distinct = if opts.count_distinct_values {
60            estimate_distinct_count(input)
61        } else {
62            u32::MAX
63        };
64
65        Self {
66            src: input.clone(),
67            value_count: value_count.try_into().vortex_expect("value_count"),
68            estimated_distinct_count: estimated_distinct,
69        }
70    }
71
72    fn source(&self) -> &VarBinViewArray {
73        &self.src
74    }
75
76    fn sample_opts(&self, sample_size: u32, sample_count: u32, opts: GenerateStatsOptions) -> Self {
77        let sampled = sample(self.src.as_ref(), sample_size, sample_count).to_varbinview();
78
79        Self::generate_opts(&sampled, opts)
80    }
81}
82
83/// [`Compressor`] for strings.
84pub struct StringCompressor;
85
86impl Compressor for StringCompressor {
87    type ArrayVTable = VarBinViewVTable;
88    type SchemeType = dyn StringScheme;
89    type StatsType = StringStats;
90
91    fn schemes() -> &'static [&'static Self::SchemeType] {
92        &[
93            &UncompressedScheme,
94            &DictScheme,
95            &FSSTScheme,
96            &ConstantScheme,
97        ]
98    }
99
100    fn default_scheme() -> &'static Self::SchemeType {
101        &UncompressedScheme
102    }
103
104    fn dict_scheme_code() -> StringCode {
105        DICT_SCHEME
106    }
107}
108
109pub trait StringScheme: Scheme<StatsType = StringStats, CodeType = StringCode> {}
110
111impl<T> StringScheme for T where T: Scheme<StatsType = StringStats, CodeType = StringCode> {}
112
/// Scheme that returns the source array as-is.
#[derive(Clone, Copy, Debug)]
pub struct UncompressedScheme;

/// Scheme that dictionary-encodes the strings.
#[derive(Clone, Copy, Debug)]
pub struct DictScheme;

/// Scheme that compresses the strings with FSST.
#[derive(Clone, Copy, Debug)]
pub struct FSSTScheme;

/// Scheme that represents the strings as a constant array.
#[derive(Clone, Copy, Debug)]
pub struct ConstantScheme;
124
/// Identifier of a string compression scheme, as returned by each scheme's `code()`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct StringCode(u8);

/// Code of [`UncompressedScheme`].
const UNCOMPRESSED_SCHEME: StringCode = StringCode(0);
/// Code of [`DictScheme`].
const DICT_SCHEME: StringCode = StringCode(1);
/// Code of [`FSSTScheme`].
const FSST_SCHEME: StringCode = StringCode(2);
/// Code of [`ConstantScheme`].
const CONSTANT_SCHEME: StringCode = StringCode(3);
132
133impl Scheme for UncompressedScheme {
134    type StatsType = StringStats;
135    type CodeType = StringCode;
136
137    fn code(&self) -> StringCode {
138        UNCOMPRESSED_SCHEME
139    }
140
141    fn expected_compression_ratio(
142        &self,
143        _stats: &Self::StatsType,
144        _is_sample: bool,
145        _allowed_cascading: usize,
146        _excludes: &[StringCode],
147    ) -> VortexResult<f64> {
148        Ok(1.0)
149    }
150
151    fn compress(
152        &self,
153        stats: &Self::StatsType,
154        _is_sample: bool,
155        _allowed_cascading: usize,
156        _excludes: &[StringCode],
157    ) -> VortexResult<ArrayRef> {
158        Ok(stats.source().to_array())
159    }
160}
161
162impl Scheme for DictScheme {
163    type StatsType = StringStats;
164    type CodeType = StringCode;
165
166    fn code(&self) -> StringCode {
167        DICT_SCHEME
168    }
169
170    fn expected_compression_ratio(
171        &self,
172        stats: &Self::StatsType,
173        is_sample: bool,
174        allowed_cascading: usize,
175        excludes: &[StringCode],
176    ) -> VortexResult<f64> {
177        // If we don't have a sufficiently high number of distinct values, do not attempt Dict.
178        if stats.estimated_distinct_count > stats.value_count / 2 {
179            return Ok(0.0);
180        }
181
182        // If array is all null, do not attempt dict.
183        if stats.value_count == 0 {
184            return Ok(0.0);
185        }
186
187        estimate_compression_ratio_with_sampling(
188            self,
189            stats,
190            is_sample,
191            allowed_cascading,
192            excludes,
193        )
194    }
195
196    fn compress(
197        &self,
198        stats: &Self::StatsType,
199        is_sample: bool,
200        allowed_cascading: usize,
201        _excludes: &[StringCode],
202    ) -> VortexResult<ArrayRef> {
203        let dict = dict_encode(&stats.source().clone().into_array())?;
204
205        // If we are not allowed to cascade, do not attempt codes or values compression.
206        if allowed_cascading == 0 {
207            return Ok(dict.into_array());
208        }
209
210        // Find best compressor for codes and values separately
211        let compressed_codes = IntCompressor::compress(
212            &dict.codes().to_primitive(),
213            is_sample,
214            allowed_cascading - 1,
215            &[integer::DictScheme.code(), integer::SequenceScheme.code()],
216        )?;
217
218        // Attempt to compress the values with non-Dict compression.
219        // Currently this will only be FSST.
220        let compressed_values = StringCompressor::compress(
221            &dict.values().to_varbinview(),
222            is_sample,
223            allowed_cascading - 1,
224            &[DictScheme.code()],
225        )?;
226
227        // SAFETY: compressing codes or values does not alter the invariants
228        unsafe { Ok(DictArray::new_unchecked(compressed_codes, compressed_values).into_array()) }
229    }
230}
231
232impl Scheme for FSSTScheme {
233    type StatsType = StringStats;
234    type CodeType = StringCode;
235
236    fn code(&self) -> StringCode {
237        FSST_SCHEME
238    }
239
240    fn compress(
241        &self,
242        stats: &Self::StatsType,
243        is_sample: bool,
244        allowed_cascading: usize,
245        _excludes: &[StringCode],
246    ) -> VortexResult<ArrayRef> {
247        let compressor = fsst_train_compressor(&stats.src.clone().into_array())?;
248        let fsst = fsst_compress(&stats.src.clone().into_array(), &compressor)?;
249
250        let compressed_original_lengths = IntCompressor::compress(
251            &fsst.uncompressed_lengths().to_primitive().narrow()?,
252            is_sample,
253            allowed_cascading,
254            &[],
255        )?;
256
257        let compressed_codes_offsets = IntCompressor::compress(
258            &fsst.codes().offsets().to_primitive().narrow()?,
259            is_sample,
260            allowed_cascading,
261            &[],
262        )?;
263        let compressed_codes = VarBinArray::try_new(
264            compressed_codes_offsets,
265            fsst.codes().bytes().clone(),
266            fsst.codes().dtype().clone(),
267            fsst.codes().validity().clone(),
268        )?;
269
270        let fsst = FSSTArray::try_new(
271            fsst.dtype().clone(),
272            fsst.symbols().clone(),
273            fsst.symbol_lengths().clone(),
274            compressed_codes,
275            compressed_original_lengths,
276        )?;
277
278        Ok(fsst.into_array())
279    }
280}
281
282impl Scheme for ConstantScheme {
283    type StatsType = StringStats;
284    type CodeType = StringCode;
285
286    fn code(&self) -> Self::CodeType {
287        CONSTANT_SCHEME
288    }
289
290    fn is_constant(&self) -> bool {
291        true
292    }
293
294    fn expected_compression_ratio(
295        &self,
296        stats: &Self::StatsType,
297        is_sample: bool,
298        _allowed_cascading: usize,
299        _excludes: &[Self::CodeType],
300    ) -> VortexResult<f64> {
301        if is_sample {
302            return Ok(0.0);
303        }
304
305        if stats.estimated_distinct_count > 1 || !stats.src.is_constant() {
306            return Ok(0.0);
307        }
308
309        // Force constant is these cases
310        Ok(f64::MAX)
311    }
312
313    fn compress(
314        &self,
315        stats: &Self::StatsType,
316        _is_sample: bool,
317        _allowed_cascading: usize,
318        _excludes: &[Self::CodeType],
319    ) -> VortexResult<ArrayRef> {
320        let scalar_idx = (0..stats.source().len()).position(|idx| stats.source().is_valid(idx));
321
322        match scalar_idx {
323            Some(idx) => {
324                let scalar = stats.source().scalar_at(idx);
325                let const_arr = ConstantArray::new(scalar, stats.src.len()).into_array();
326                if !stats.source().all_valid() {
327                    Ok(MaskedArray::try_new(const_arr, stats.src.validity().clone())?.into_array())
328                } else {
329                    Ok(const_arr)
330                }
331            }
332            None => Ok(ConstantArray::new(
333                Scalar::null(stats.src.dtype().clone()),
334                stats.src.len(),
335            )
336            .into_array()),
337        }
338    }
339}
340
#[cfg(test)]
mod tests {
    use vortex_array::arrays::VarBinViewArray;
    use vortex_dtype::{DType, Nullability};

    use crate::Compressor;
    use crate::string::StringCompressor;

    /// Compresses 2048 strings drawn from two distinct values and checks that the
    /// compressed array still has one entry per input string.
    ///
    /// The array trees are printed to make the chosen encoding visible when running with
    /// `--nocapture`.
    #[test]
    fn test_strings() {
        // 1024 copies of each of the two values.
        let mut values = vec![Some("hello-world-1234"); 1024];
        values.resize(2048, Some("hello-world-56789"));
        let strings = VarBinViewArray::from_iter(values, DType::Utf8(Nullability::NonNullable));

        println!("original array: {}", strings.as_ref().display_tree());

        let compressed = StringCompressor::compress(&strings, false, 3, &[]).unwrap();

        println!("compression tree: {}", compressed.display_tree());

        // Without an assertion the test would pass for any output that does not panic.
        assert_eq!(compressed.len(), strings.len());
    }
}
366}