vortex_zstd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::sync::Arc;
6
7use vortex_array::accessor::ArrayAccessor;
8use vortex_array::arrays::{BinaryView, PrimitiveArray, VarBinViewArray};
9use vortex_array::compute::filter;
10use vortex_array::stats::{ArrayStats, StatsSetRef};
11use vortex_array::validity::Validity;
12use vortex_array::vtable::{
13    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
14    ValiditySliceHelper, ValidityVTableFromValiditySliceHelper,
15};
16use vortex_array::{ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
17use vortex_buffer::{Alignment, Buffer, BufferMut, ByteBuffer, ByteBufferMut};
18use vortex_dtype::DType;
19use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
20use vortex_mask::AllOr;
21use vortex_scalar::Scalar;
22
23use crate::serde::{ZstdFrameMetadata, ZstdMetadata};
24
// Zstd doesn't support training dictionaries on very few samples, so below
// this many frames we skip dictionary training entirely.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
// Integer type of the little-endian length prefix written before each value
// in the VarBinView compression path (see `collect_valid_vbv` /
// `reconstruct_views`).
type ViewLen = u32;
28
29// Overall approach here:
30// Zstd can be used on the whole array (values_per_frame = 0), resulting in a single Zstd
31// frame, or it can be used with a dictionary (values_per_frame < # values), resulting in
32// multiple Zstd frames sharing a common dictionary. This latter case is helpful if you
33// want somewhat faster access to slices or individual rows, allowing us to only
34// decompress the necessary frames.
35
36// Visually, during decompression, we have an interval of frames we're
37// decompressing and a tighter interval of the slice we actually care about.
38// |=============values (all valid elements)==============|
39// |<-skipped_uncompressed->|----decompressed-------------|
40//                              |------slice-------|
41//                              ^                  ^
42// |<-slice_uncompressed_start->|                  |
43// |<------------slice_uncompressed_stop---------->|
// We then insert these values into the correct positions using a primitive
// array constructor.
46
47vtable!(Zstd);
48
impl VTable for ZstdVTable {
    type Array = ZstdArray;
    type Encoding = ZstdEncoding;

    type ArrayVTable = Self;
    type CanonicalVTable = Self;
    type OperationsVTable = Self;
    // Validity is derived from the unsliced validity plus the current slice
    // bounds via the `ValiditySliceHelper` impl on `ZstdArray`.
    type ValidityVTable = ValidityVTableFromValiditySliceHelper;
    type VisitorVTable = Self;
    // No specialized compute kernels; callers fall back to canonicalizing.
    type ComputeVTable = NotSupported;
    type EncodeVTable = Self;
    type SerdeVTable = Self;
    type PipelineVTable = NotSupported;

    fn id(_encoding: &Self::Encoding) -> EncodingId {
        EncodingId::new_ref("vortex.zstd")
    }

    fn encoding(_array: &Self::Array) -> EncodingRef {
        EncodingRef::new_ref(ZstdEncoding.as_ref())
    }
}
71
/// Marker type for the Zstd encoding, registered under the id `vortex.zstd`.
#[derive(Clone, Debug)]
pub struct ZstdEncoding;
74
/// An array whose valid values are stored as one or more zstd frames
/// (optionally sharing a trained dictionary), decompressed lazily on demand.
#[derive(Clone, Debug)]
pub struct ZstdArray {
    /// Shared zstd dictionary; present when the data was split into enough
    /// frames to train one.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// The compressed zstd frames, in value order.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Per-frame uncompressed sizes and value counts, plus dictionary size.
    pub(crate) metadata: ZstdMetadata,
    dtype: DType,
    // Validity of the *unsliced* array; slicing only moves the bounds below.
    pub(crate) unsliced_validity: Validity,
    // Row count before any slicing.
    unsliced_n_rows: usize,
    stats_set: ArrayStats,
    // Current visible window (row indices into the unsliced array).
    slice_start: usize,
    slice_stop: usize,
}
87
/// Output of `ZstdArray::compress_values`: the trained dictionary (if any)
/// together with the compressed frames and their per-frame metadata.
struct Frames {
    dictionary: Option<ByteBuffer>,
    frames: Vec<ByteBuffer>,
    frame_metas: Vec<ZstdFrameMetadata>,
}
93
/// Picks an upper bound for the trained dictionary size given the total
/// uncompressed data size.
///
/// Follows the guidance in zstd's `zdict.h`
/// (https://github.com/facebook/zstd/blob/v1.5.5/lib/zdict.h#L190): roughly
/// 1/100th of the training data, capped at 100 KiB. The floor exists because
/// zstd cannot train dictionaries smaller than 256 bytes.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const FLOOR: usize = 256;
    const CEILING: usize = 100 * 1024;
    (uncompressed_size / 100).clamp(FLOOR, CEILING)
}
101
102fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
103    let mask = parray.validity_mask()?;
104    filter(&parray.to_array(), &mask)?.to_primitive()
105}
106
/// Gathers all valid (non-null) values of `vbv` into one contiguous buffer,
/// each value prefixed with its byte length as a little-endian `ViewLen`.
///
/// Also returns, for each valid value, the byte offset in the buffer where
/// that value's length prefix starts; `from_var_bin_view` uses these offsets
/// to place frame boundaries on value starts.
fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
    let mask = vbv.validity_mask()?;
    let buffer_and_value_byte_indices = match mask.boolean_buffer() {
        // Nothing is valid, so there is nothing to collect.
        AllOr::None => (Buffer::empty(), Vec::new()),
        _ => {
            // Capacity estimate: all stored bytes plus one length prefix per
            // valid value.
            let mut buffer = BufferMut::with_capacity(
                usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
                    + mask.true_count() * size_of::<ViewLen>(),
            );
            let mut value_byte_indices = Vec::new();
            vbv.with_iterator(|iterator| {
                // by flattening, we should omit nulls
                for value in iterator.flatten() {
                    value_byte_indices.push(buffer.len());
                    // here's where we write the string lengths
                    buffer
                        .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
                    buffer.extend_from_slice(value);
                }
                Ok::<_, VortexError>(())
            })??;
            (buffer.freeze(), value_byte_indices)
        }
    };
    Ok(buffer_and_value_byte_indices)
}
133
134fn reconstruct_views(buffer: ByteBuffer) -> VortexResult<Buffer<BinaryView>> {
135    let mut res = BufferMut::<BinaryView>::empty();
136    let mut offset = 0;
137    while offset < buffer.len() {
138        let str_len = ViewLen::from_le_bytes(
139            buffer
140                .get(offset..offset + size_of::<ViewLen>())
141                .ok_or_else(|| vortex_err!("Zstd buffer for VarBinView was corrupt"))?
142                .try_into()?,
143        ) as usize;
144        offset += size_of::<ViewLen>();
145        let value = &buffer[offset..offset + str_len];
146        res.push(BinaryView::make_view(value, 0, u32::try_from(offset)?));
147        offset += str_len;
148    }
149    Ok(res.freeze())
150}
151
152impl ZstdArray {
153    pub fn new(
154        dictionary: Option<ByteBuffer>,
155        frames: Vec<ByteBuffer>,
156        dtype: DType,
157        metadata: ZstdMetadata,
158        n_rows: usize,
159        validity: Validity,
160    ) -> Self {
161        Self {
162            dictionary,
163            frames,
164            metadata,
165            dtype,
166            unsliced_validity: validity,
167            unsliced_n_rows: n_rows,
168            stats_set: Default::default(),
169            slice_start: 0,
170            slice_stop: n_rows,
171        }
172    }
173
174    fn compress_values(
175        value_bytes: &ByteBuffer,
176        frame_byte_starts: &[usize],
177        level: i32,
178        values_per_frame: usize,
179        n_values: usize,
180    ) -> VortexResult<Frames> {
181        let n_frames = frame_byte_starts.len();
182
183        // Would-be sample sizes if we end up applying zstd dictionary
184        let mut sample_sizes = Vec::with_capacity(n_frames);
185        for i in 0..n_frames {
186            let frame_byte_end = frame_byte_starts
187                .get(i + 1)
188                .copied()
189                .unwrap_or(value_bytes.len());
190            sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
191        }
192        debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
193
194        let (dictionary, mut compressor) = if sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY {
195            // no dictionary
196            (None, zstd::bulk::Compressor::new(level)?)
197        } else {
198            // with dictionary
199            let max_dict_size = choose_max_dict_size(value_bytes.len());
200            let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
201                .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
202
203            let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
204            (Some(ByteBuffer::from(dict)), compressor)
205        };
206
207        let mut frame_metas = vec![];
208        let mut frames = vec![];
209        for i in 0..n_frames {
210            let frame_byte_end = frame_byte_starts
211                .get(i + 1)
212                .copied()
213                .unwrap_or(value_bytes.len());
214            let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
215            let compressed = compressor
216                .compress(uncompressed)
217                .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
218            frame_metas.push(ZstdFrameMetadata {
219                uncompressed_size: uncompressed.len() as u64,
220                n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
221            });
222            frames.push(ByteBuffer::from(compressed));
223        }
224
225        Ok(Frames {
226            dictionary,
227            frames,
228            frame_metas,
229        })
230    }
231
232    pub fn from_primitive(
233        parray: &PrimitiveArray,
234        level: i32,
235        values_per_frame: usize,
236    ) -> VortexResult<Self> {
237        let dtype = parray.dtype().clone();
238        let byte_width = parray.ptype().byte_width();
239
240        // We compress only the valid elements.
241        let values = collect_valid_primitive(parray)?;
242        let n_values = values.len();
243        let values_per_frame = if values_per_frame > 0 {
244            values_per_frame
245        } else {
246            n_values
247        };
248
249        let value_bytes = values.byte_buffer();
250        let frame_byte_starts = (0..n_values * byte_width)
251            .step_by(values_per_frame * byte_width)
252            .collect::<Vec<_>>();
253        let Frames {
254            dictionary,
255            frames,
256            frame_metas,
257        } = Self::compress_values(
258            value_bytes,
259            &frame_byte_starts,
260            level,
261            values_per_frame,
262            n_values,
263        )?;
264
265        let metadata = ZstdMetadata {
266            dictionary_size: dictionary
267                .as_ref()
268                .map_or(0, |dict| dict.len())
269                .try_into()?,
270            frames: frame_metas,
271        };
272
273        Ok(ZstdArray::new(
274            dictionary,
275            frames,
276            dtype,
277            metadata,
278            parray.len(),
279            parray.validity().clone(),
280        ))
281    }
282
283    pub fn from_var_bin_view(
284        vbv: &VarBinViewArray,
285        level: i32,
286        values_per_frame: usize,
287    ) -> VortexResult<Self> {
288        // Approach for strings: we prefix each string with its length as a u32.
289        // This is the same as what Parquet does. In some cases it may be better
290        // to separate the binary data and lengths as two separate streams, but
291        // this approach is simpler and can be best in cases when there is
292        // mutual information between strings and their lengths.
293        let dtype = vbv.dtype().clone();
294
295        // We compress only the valid elements.
296        let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
297        let n_values = value_byte_indices.len();
298        let values_per_frame = if values_per_frame > 0 {
299            values_per_frame
300        } else {
301            n_values
302        };
303
304        let frame_byte_starts = (0..n_values)
305            .step_by(values_per_frame)
306            .map(|i| value_byte_indices[i])
307            .collect::<Vec<_>>();
308        let Frames {
309            dictionary,
310            frames,
311            frame_metas,
312        } = Self::compress_values(
313            &value_bytes,
314            &frame_byte_starts,
315            level,
316            values_per_frame,
317            n_values,
318        )?;
319
320        let metadata = ZstdMetadata {
321            dictionary_size: dictionary
322                .as_ref()
323                .map_or(0, |dict| dict.len())
324                .try_into()?,
325            frames: frame_metas,
326        };
327        Ok(ZstdArray::new(
328            dictionary,
329            frames,
330            dtype,
331            metadata,
332            vbv.len(),
333            vbv.validity().clone(),
334        ))
335    }
336
337    pub fn from_canonical(
338        canonical: &Canonical,
339        level: i32,
340        values_per_frame: usize,
341    ) -> VortexResult<Option<Self>> {
342        match canonical {
343            Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
344                parray,
345                level,
346                values_per_frame,
347            )?)),
348            Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
349                vbv,
350                level,
351                values_per_frame,
352            )?)),
353            _ => Ok(None),
354        }
355    }
356
357    pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
358        Self::from_canonical(&array.to_canonical()?, level, values_per_frame)?
359            .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
360    }
361
362    fn byte_width(&self) -> usize {
363        if self.dtype.is_primitive() {
364            self.dtype.as_ptype().byte_width()
365        } else {
366            1
367        }
368    }
369
370    pub fn decompress(&self) -> VortexResult<ArrayRef> {
371        // To start, we figure out which frames we need to decompress, and with
372        // what row offset into the first such frame.
373        let byte_width = self.byte_width();
374        let slice_n_rows = self.slice_stop - self.slice_start;
375        let slice_value_indices = self
376            .unsliced_validity
377            .to_mask(self.unsliced_n_rows)?
378            .valid_counts_for_indices(&[self.slice_start, self.slice_stop])?;
379
380        let slice_value_idx_start = slice_value_indices[0];
381        let slice_value_idx_stop = slice_value_indices[1];
382
383        let mut frames_to_decompress = vec![];
384        let mut value_idx_start = 0;
385        let mut uncompressed_size_to_decompress = 0;
386        let mut n_skipped_values = 0;
387        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
388            if value_idx_start >= slice_value_idx_stop {
389                break;
390            }
391
392            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)?;
393            let frame_n_values = if frame_meta.n_values == 0 {
394                // possibly older primitive-only metadata that just didn't store this
395                frame_uncompressed_size / byte_width
396            } else {
397                usize::try_from(frame_meta.n_values)?
398            };
399
400            let value_idx_stop = value_idx_start + frame_n_values;
401            if value_idx_stop > slice_value_idx_start {
402                // we need this frame
403                frames_to_decompress.push(frame);
404                uncompressed_size_to_decompress += frame_uncompressed_size;
405            } else {
406                n_skipped_values += frame_n_values;
407            }
408            value_idx_start = value_idx_stop;
409        }
410
411        // then we actually decompress those frames
412        let mut decompressor = if let Some(dictionary) = &self.dictionary {
413            zstd::bulk::Decompressor::with_dictionary(dictionary)
414        } else {
415            zstd::bulk::Decompressor::new()
416        }?;
417        let mut decompressed = ByteBufferMut::with_capacity_aligned(
418            uncompressed_size_to_decompress,
419            Alignment::new(byte_width),
420        );
421        unsafe {
422            // safety: we immediately fill all bytes in the following loop,
423            // assuming our metadata's uncompressed size is correct
424            decompressed.set_len(uncompressed_size_to_decompress);
425        }
426        let mut uncompressed_start = 0;
427        for frame in frames_to_decompress {
428            let uncompressed_written = decompressor
429                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])
430                .map_err(|err| VortexError::from(err).with_context("while decompressing"))?;
431            uncompressed_start += uncompressed_written;
432        }
433        if uncompressed_start != uncompressed_size_to_decompress {
434            vortex_bail!(
435                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
436                uncompressed_size_to_decompress,
437                uncompressed_start
438            );
439        }
440
441        let decompressed = decompressed.freeze();
442        // Last, we slice the exact values requested out of the decompressed data.
443        let slice_validity = self
444            .unsliced_validity
445            .slice(self.slice_start, self.slice_stop);
446
447        match &self.dtype {
448            DType::Primitive(..) => {
449                let slice_values_buffer = decompressed.slice(
450                    (slice_value_idx_start - n_skipped_values) * byte_width
451                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
452                );
453                let primitive = PrimitiveArray::from_values_byte_buffer(
454                    slice_values_buffer,
455                    self.dtype.as_ptype(),
456                    slice_validity,
457                    slice_n_rows,
458                )?;
459
460                Ok(primitive.into_array())
461            }
462            DType::Binary(_) | DType::Utf8(_) => {
463                // The decompressed buffer is a bunch of interleaved u32 lengths
464                // and strings of those lengths, we need to reconstruct the
465                // views into those strings by passing through the buffer.
466                let views = reconstruct_views(decompressed.clone())?.slice(
467                    slice_value_idx_start - n_skipped_values
468                        ..slice_value_idx_stop - n_skipped_values,
469                );
470
471                // SAFETY: we properly construct the views inside `reconstruct_views`
472                let vbv = unsafe {
473                    VarBinViewArray::new_unchecked(
474                        views,
475                        Arc::from([decompressed]),
476                        self.dtype.clone(),
477                        slice_validity,
478                    )
479                };
480                Ok(vbv.into_array())
481            }
482            _ => Err(vortex_err!(
483                "Unsupported dtype for Zstd array: {:?}",
484                self.dtype
485            )),
486        }
487    }
488
489    pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
490        ZstdArray {
491            slice_start: self.slice_start + start,
492            slice_stop: self.slice_start + stop,
493            stats_set: Default::default(),
494            ..self.clone()
495        }
496    }
497
498    pub(crate) fn dtype(&self) -> &DType {
499        &self.dtype
500    }
501
502    pub(crate) fn slice_start(&self) -> usize {
503        self.slice_start
504    }
505
506    pub(crate) fn slice_stop(&self) -> usize {
507        self.slice_stop
508    }
509
510    pub(crate) fn unsliced_n_rows(&self) -> usize {
511        self.unsliced_n_rows
512    }
513}
514
impl ValiditySliceHelper for ZstdArray {
    // Hands the shared validity machinery the full (unsliced) validity plus
    // the current slice bounds so it can derive validity for the visible
    // window.
    fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
        (&self.unsliced_validity, self.slice_start, self.slice_stop)
    }
}
520
impl ArrayVTable<ZstdVTable> for ZstdVTable {
    // The logical length is the width of the current slice window.
    fn len(array: &ZstdArray) -> usize {
        array.slice_stop - array.slice_start
    }

    fn dtype(array: &ZstdArray) -> &DType {
        &array.dtype
    }

    fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
        array.stats_set.to_ref(array.as_ref())
    }
}
534
impl CanonicalVTable<ZstdVTable> for ZstdVTable {
    // Canonicalizing means fully decompressing the current slice.
    fn canonicalize(array: &ZstdArray) -> VortexResult<Canonical> {
        array.decompress()?.to_canonical()
    }
}
540
impl OperationsVTable<ZstdVTable> for ZstdVTable {
    // Slicing just narrows the window; decompression stays deferred.
    fn slice(array: &ZstdArray, start: usize, stop: usize) -> ArrayRef {
        array._slice(start, stop).into_array()
    }

    // Decompresses a one-row slice (only the frame containing `index` is
    // touched). The trait signature is infallible, hence `vortex_expect` if
    // the stored frames turn out to be corrupt.
    fn scalar_at(array: &ZstdArray, index: usize) -> Scalar {
        array
            ._slice(index, index + 1)
            .decompress()
            .vortex_expect("zstd decompress")
            .scalar_at(0)
    }
}