vortex_zstd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::ops::Range;
6use std::sync::Arc;
7
8use itertools::Itertools as _;
9use vortex_array::accessor::ArrayAccessor;
10use vortex_array::arrays::binary_view::BinaryView;
11use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinViewArray};
12use vortex_array::compute::filter;
13use vortex_array::stats::{ArrayStats, StatsSetRef};
14use vortex_array::validity::Validity;
15use vortex_array::vtable::{
16    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
17    ValiditySliceHelper, ValidityVTableFromValiditySliceHelper,
18};
19use vortex_array::{ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
20use vortex_buffer::{Alignment, Buffer, BufferMut, ByteBuffer, ByteBufferMut};
21use vortex_dtype::DType;
22use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_err, vortex_panic};
23use vortex_mask::AllOr;
24use vortex_scalar::Scalar;
25
26use crate::serde::{ZstdFrameMetadata, ZstdMetadata};
27
// Zstd doesn't support training dictionaries on very few samples.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
// Width of the little-endian length prefix written before each value in the
// uncompressed variable-length stream (see `collect_valid_vbv` /
// `reconstruct_views`).
type ViewLen = u32;
31
32// Overall approach here:
33// Zstd can be used on the whole array (values_per_frame = 0), resulting in a single Zstd
34// frame, or it can be used with a dictionary (values_per_frame < # values), resulting in
35// multiple Zstd frames sharing a common dictionary. This latter case is helpful if you
36// want somewhat faster access to slices or individual rows, allowing us to only
37// decompress the necessary frames.
38
39// Visually, during decompression, we have an interval of frames we're
40// decompressing and a tighter interval of the slice we actually care about.
41// |=============values (all valid elements)==============|
42// |<-skipped_uncompressed->|----decompressed-------------|
43//                              |------slice-------|
44//                              ^                  ^
45// |<-slice_uncompressed_start->|                  |
46// |<------------slice_uncompressed_stop---------->|
47// We then insert these values to the correct position using a primitive array
48// constructor.
49
50vtable!(Zstd);
51
impl VTable for ZstdVTable {
    type Array = ZstdArray;
    type Encoding = ZstdEncoding;

    type ArrayVTable = Self;
    type CanonicalVTable = Self;
    type OperationsVTable = Self;
    // Validity is derived lazily from the unsliced validity plus the slice
    // bounds via the `ValiditySliceHelper` impl on `ZstdArray`.
    type ValidityVTable = ValidityVTableFromValiditySliceHelper;
    type VisitorVTable = Self;
    // No compute or pipeline support for this encoding.
    type ComputeVTable = NotSupported;
    type EncodeVTable = Self;
    type SerdeVTable = Self;
    type PipelineVTable = NotSupported;

    fn id(_encoding: &Self::Encoding) -> EncodingId {
        EncodingId::new_ref("vortex.zstd")
    }

    fn encoding(_array: &Self::Array) -> EncodingRef {
        EncodingRef::new_ref(ZstdEncoding.as_ref())
    }
}
74
/// Marker type for the `vortex.zstd` encoding.
#[derive(Clone, Debug)]
pub struct ZstdEncoding;
77
/// An array whose valid values are stored as one or more Zstd-compressed
/// frames, optionally sharing a trained Zstd dictionary.
///
/// Slicing is lazy: the compressed frames and unsliced validity are kept
/// intact and only `slice_start`/`slice_stop` change.
#[derive(Clone, Debug)]
pub struct ZstdArray {
    /// Shared dictionary, present when enough frames existed to train one.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// Compressed Zstd frames covering the valid values, in order.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Per-frame uncompressed sizes / value counts plus the dictionary size.
    pub(crate) metadata: ZstdMetadata,
    dtype: DType,
    /// Validity of the full, unsliced array.
    pub(crate) unsliced_validity: Validity,
    /// Row count of the full, unsliced array.
    unsliced_n_rows: usize,
    stats_set: ArrayStats,
    /// Current logical slice bounds (row indices into the unsliced array).
    slice_start: usize,
    slice_stop: usize,
}
90
/// Intermediate result of `ZstdArray::compress_values`: the compressed frames,
/// the optional shared dictionary they were compressed with, and their
/// per-frame metadata.
struct Frames {
    dictionary: Option<ByteBuffer>,
    frames: Vec<ByteBuffer>,
    frame_metas: Vec<ZstdFrameMetadata>,
}
96
/// Pick the maximum dictionary size to train for `uncompressed_size` bytes of
/// sample data.
///
/// Follows the recommendation in
/// https://github.com/facebook/zstd/blob/v1.5.5/lib/zdict.h#L190 — roughly
/// 1/100 of the data size, capped at 100 kB — with a floor of 256 bytes since
/// zstd can't train dictionaries smaller than that.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const MIN_DICT_SIZE: usize = 256;
    const MAX_DICT_SIZE: usize = 100 * 1024;
    let target = uncompressed_size / 100;
    target.max(MIN_DICT_SIZE).min(MAX_DICT_SIZE)
}
104
/// Return a primitive array containing only the valid elements of `parray`,
/// in order — invalid slots are filtered out entirely, so the result is dense.
fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
    let mask = parray.validity_mask();
    Ok(filter(&parray.to_array(), &mask)?.to_primitive())
}
109
/// Serialize the valid elements of `vbv` into a single contiguous buffer:
/// each value is written as a little-endian `ViewLen` length prefix followed
/// by its bytes (the inverse of `reconstruct_views`).
///
/// Returns the buffer together with the byte offset at which each value's
/// record (length prefix) starts; these offsets are later used to choose
/// frame boundaries. An all-null array yields an empty buffer and no offsets.
fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
    let mask = vbv.validity_mask();
    let buffer_and_value_byte_indices = match mask.boolean_buffer() {
        AllOr::None => (Buffer::empty(), Vec::new()),
        _ => {
            // Reserve room for all value bytes plus one length prefix per
            // valid element.
            let mut buffer = BufferMut::with_capacity(
                usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
                    + mask.true_count() * size_of::<ViewLen>(),
            );
            let mut value_byte_indices = Vec::new();
            vbv.with_iterator(|iterator| {
                // by flattening, we should omit nulls
                for value in iterator.flatten() {
                    value_byte_indices.push(buffer.len());
                    // here's where we write the string lengths
                    buffer
                        .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
                    buffer.extend_from_slice(value);
                }
                Ok::<_, VortexError>(())
            })??;
            (buffer.freeze(), value_byte_indices)
        }
    };
    Ok(buffer_and_value_byte_indices)
}
136
/// Inverse of the layout written by `collect_valid_vbv`: walk a buffer of
/// interleaved `ViewLen` length prefixes and value bytes, producing one
/// `BinaryView` per value.
///
/// Each view records buffer index 0 and the value's byte offset into
/// `buffer`, so callers must supply this same buffer as the sole data buffer
/// of the resulting `VarBinViewArray`.
fn reconstruct_views(buffer: ByteBuffer) -> Buffer<BinaryView> {
    let mut res = BufferMut::<BinaryView>::empty();
    let mut offset = 0;
    while offset < buffer.len() {
        // Read the little-endian length prefix for the next value.
        let str_len = ViewLen::from_le_bytes(
            buffer
                .get(offset..offset + size_of::<ViewLen>())
                .vortex_expect("corrupted zstd length")
                .try_into()
                .vortex_expect("must fit ViewLen size"),
        ) as usize;
        offset += size_of::<ViewLen>();
        // `offset` now points at the value bytes themselves; the view stores
        // this offset so it can be resolved against the same buffer later.
        let value = &buffer[offset..offset + str_len];
        res.push(BinaryView::make_view(
            value,
            0,
            u32::try_from(offset).vortex_expect("offset must fit in u32"),
        ));
        offset += str_len;
    }
    res.freeze()
}
159
160impl ZstdArray {
161    pub fn new(
162        dictionary: Option<ByteBuffer>,
163        frames: Vec<ByteBuffer>,
164        dtype: DType,
165        metadata: ZstdMetadata,
166        n_rows: usize,
167        validity: Validity,
168    ) -> Self {
169        Self {
170            dictionary,
171            frames,
172            metadata,
173            dtype,
174            unsliced_validity: validity,
175            unsliced_n_rows: n_rows,
176            stats_set: Default::default(),
177            slice_start: 0,
178            slice_stop: n_rows,
179        }
180    }
181
182    fn compress_values(
183        value_bytes: &ByteBuffer,
184        frame_byte_starts: &[usize],
185        level: i32,
186        values_per_frame: usize,
187        n_values: usize,
188    ) -> VortexResult<Frames> {
189        let n_frames = frame_byte_starts.len();
190
191        // Would-be sample sizes if we end up applying zstd dictionary
192        let mut sample_sizes = Vec::with_capacity(n_frames);
193        for i in 0..n_frames {
194            let frame_byte_end = frame_byte_starts
195                .get(i + 1)
196                .copied()
197                .unwrap_or(value_bytes.len());
198            sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
199        }
200        debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
201
202        let (dictionary, mut compressor) = if sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY {
203            // no dictionary
204            (None, zstd::bulk::Compressor::new(level)?)
205        } else {
206            // with dictionary
207            let max_dict_size = choose_max_dict_size(value_bytes.len());
208            let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
209                .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
210
211            let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
212            (Some(ByteBuffer::from(dict)), compressor)
213        };
214
215        let mut frame_metas = vec![];
216        let mut frames = vec![];
217        for i in 0..n_frames {
218            let frame_byte_end = frame_byte_starts
219                .get(i + 1)
220                .copied()
221                .unwrap_or(value_bytes.len());
222            let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
223            let compressed = compressor
224                .compress(uncompressed)
225                .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
226            frame_metas.push(ZstdFrameMetadata {
227                uncompressed_size: uncompressed.len() as u64,
228                n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
229            });
230            frames.push(ByteBuffer::from(compressed));
231        }
232
233        Ok(Frames {
234            dictionary,
235            frames,
236            frame_metas,
237        })
238    }
239
240    pub fn from_primitive(
241        parray: &PrimitiveArray,
242        level: i32,
243        values_per_frame: usize,
244    ) -> VortexResult<Self> {
245        let dtype = parray.dtype().clone();
246        let byte_width = parray.ptype().byte_width();
247
248        // We compress only the valid elements.
249        let values = collect_valid_primitive(parray)?;
250        let n_values = values.len();
251        let values_per_frame = if values_per_frame > 0 {
252            values_per_frame
253        } else {
254            n_values
255        };
256
257        let value_bytes = values.byte_buffer();
258        let frame_byte_starts = (0..n_values * byte_width)
259            .step_by(values_per_frame * byte_width)
260            .collect::<Vec<_>>();
261        let Frames {
262            dictionary,
263            frames,
264            frame_metas,
265        } = Self::compress_values(
266            value_bytes,
267            &frame_byte_starts,
268            level,
269            values_per_frame,
270            n_values,
271        )?;
272
273        let metadata = ZstdMetadata {
274            dictionary_size: dictionary
275                .as_ref()
276                .map_or(0, |dict| dict.len())
277                .try_into()?,
278            frames: frame_metas,
279        };
280
281        Ok(ZstdArray::new(
282            dictionary,
283            frames,
284            dtype,
285            metadata,
286            parray.len(),
287            parray.validity().clone(),
288        ))
289    }
290
291    pub fn from_var_bin_view(
292        vbv: &VarBinViewArray,
293        level: i32,
294        values_per_frame: usize,
295    ) -> VortexResult<Self> {
296        // Approach for strings: we prefix each string with its length as a u32.
297        // This is the same as what Parquet does. In some cases it may be better
298        // to separate the binary data and lengths as two separate streams, but
299        // this approach is simpler and can be best in cases when there is
300        // mutual information between strings and their lengths.
301        let dtype = vbv.dtype().clone();
302
303        // We compress only the valid elements.
304        let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
305        let n_values = value_byte_indices.len();
306        let values_per_frame = if values_per_frame > 0 {
307            values_per_frame
308        } else {
309            n_values
310        };
311
312        let frame_byte_starts = (0..n_values)
313            .step_by(values_per_frame)
314            .map(|i| value_byte_indices[i])
315            .collect::<Vec<_>>();
316        let Frames {
317            dictionary,
318            frames,
319            frame_metas,
320        } = Self::compress_values(
321            &value_bytes,
322            &frame_byte_starts,
323            level,
324            values_per_frame,
325            n_values,
326        )?;
327
328        let metadata = ZstdMetadata {
329            dictionary_size: dictionary
330                .as_ref()
331                .map_or(0, |dict| dict.len())
332                .try_into()?,
333            frames: frame_metas,
334        };
335        Ok(ZstdArray::new(
336            dictionary,
337            frames,
338            dtype,
339            metadata,
340            vbv.len(),
341            vbv.validity().clone(),
342        ))
343    }
344
345    pub fn from_canonical(
346        canonical: &Canonical,
347        level: i32,
348        values_per_frame: usize,
349    ) -> VortexResult<Option<Self>> {
350        match canonical {
351            Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
352                parray,
353                level,
354                values_per_frame,
355            )?)),
356            Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
357                vbv,
358                level,
359                values_per_frame,
360            )?)),
361            _ => Ok(None),
362        }
363    }
364
365    pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
366        Self::from_canonical(&array.to_canonical(), level, values_per_frame)?
367            .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
368    }
369
370    fn byte_width(&self) -> usize {
371        if self.dtype.is_primitive() {
372            self.dtype.as_ptype().byte_width()
373        } else {
374            1
375        }
376    }
377
378    pub fn decompress(&self) -> ArrayRef {
379        // To start, we figure out which frames we need to decompress, and with
380        // what row offset into the first such frame.
381        let byte_width = self.byte_width();
382        let slice_n_rows = self.slice_stop - self.slice_start;
383        let slice_value_indices = self
384            .unsliced_validity
385            .to_mask(self.unsliced_n_rows)
386            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
387
388        let slice_value_idx_start = slice_value_indices[0];
389        let slice_value_idx_stop = slice_value_indices[1];
390
391        let mut frames_to_decompress = vec![];
392        let mut value_idx_start = 0;
393        let mut uncompressed_size_to_decompress = 0;
394        let mut n_skipped_values = 0;
395        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
396            if value_idx_start >= slice_value_idx_stop {
397                break;
398            }
399
400            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
401                .vortex_expect("Uncompressed size must fit in usize");
402            let frame_n_values = if frame_meta.n_values == 0 {
403                // possibly older primitive-only metadata that just didn't store this
404                frame_uncompressed_size / byte_width
405            } else {
406                usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
407            };
408
409            let value_idx_stop = value_idx_start + frame_n_values;
410            if value_idx_stop > slice_value_idx_start {
411                // we need this frame
412                frames_to_decompress.push(frame);
413                uncompressed_size_to_decompress += frame_uncompressed_size;
414            } else {
415                n_skipped_values += frame_n_values;
416            }
417            value_idx_start = value_idx_stop;
418        }
419
420        // then we actually decompress those frames
421        let mut decompressor = if let Some(dictionary) = &self.dictionary {
422            zstd::bulk::Decompressor::with_dictionary(dictionary)
423        } else {
424            zstd::bulk::Decompressor::new()
425        }
426        .vortex_expect("Decompressor encountered io error");
427        let mut decompressed = ByteBufferMut::with_capacity_aligned(
428            uncompressed_size_to_decompress,
429            Alignment::new(byte_width),
430        );
431        unsafe {
432            // safety: we immediately fill all bytes in the following loop,
433            // assuming our metadata's uncompressed size is correct
434            decompressed.set_len(uncompressed_size_to_decompress);
435        }
436        let mut uncompressed_start = 0;
437        for frame in frames_to_decompress {
438            let uncompressed_written = decompressor
439                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])
440                .vortex_expect("error while decompressing zstd array");
441            uncompressed_start += uncompressed_written;
442        }
443        if uncompressed_start != uncompressed_size_to_decompress {
444            vortex_panic!(
445                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
446                uncompressed_size_to_decompress,
447                uncompressed_start
448            );
449        }
450
451        let decompressed = decompressed.freeze();
452        // Last, we slice the exact values requested out of the decompressed data.
453        let slice_validity = self
454            .unsliced_validity
455            .slice(self.slice_start..self.slice_stop);
456
457        match &self.dtype {
458            DType::Primitive(..) => {
459                let slice_values_buffer = decompressed.slice(
460                    (slice_value_idx_start - n_skipped_values) * byte_width
461                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
462                );
463                let primitive = PrimitiveArray::from_values_byte_buffer(
464                    slice_values_buffer,
465                    self.dtype.as_ptype(),
466                    slice_validity,
467                    slice_n_rows,
468                );
469
470                primitive.into_array()
471            }
472            DType::Binary(_) | DType::Utf8(_) => {
473                match slice_validity.to_mask(slice_n_rows).indices() {
474                    AllOr::All => {
475                        // the decompressed buffer is a bunch of interleaved u32 lengths
476                        // and strings of those lengths, we need to reconstruct the
477                        // views into those strings by passing through the buffer.
478                        let valid_views = reconstruct_views(decompressed.clone()).slice(
479                            slice_value_idx_start - n_skipped_values
480                                ..slice_value_idx_stop - n_skipped_values,
481                        );
482
483                        // SAFETY: we properly construct the views inside `reconstruct_views`
484                        unsafe {
485                            VarBinViewArray::new_unchecked(
486                                valid_views,
487                                Arc::from([decompressed]),
488                                self.dtype.clone(),
489                                slice_validity,
490                            )
491                        }
492                        .into_array()
493                    }
494                    AllOr::None => {
495                        ConstantArray::new(Scalar::null(self.dtype.clone()), slice_n_rows)
496                            .into_array()
497                    }
498                    AllOr::Some(valid_indices) => {
499                        // the decompressed buffer is a bunch of interleaved u32 lengths
500                        // and strings of those lengths, we need to reconstruct the
501                        // views into those strings by passing through the buffer.
502                        let valid_views = reconstruct_views(decompressed.clone()).slice(
503                            slice_value_idx_start - n_skipped_values
504                                ..slice_value_idx_stop - n_skipped_values,
505                        );
506
507                        let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
508                        for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
509                            views[*index] = view
510                        }
511
512                        // SAFETY: we properly construct the views inside `reconstruct_views`
513                        unsafe {
514                            VarBinViewArray::new_unchecked(
515                                views.freeze(),
516                                Arc::from([decompressed]),
517                                self.dtype.clone(),
518                                slice_validity,
519                            )
520                        }
521                        .into_array()
522                    }
523                }
524            }
525            _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype),
526        }
527    }
528
529    pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
530        ZstdArray {
531            slice_start: self.slice_start + start,
532            slice_stop: self.slice_start + stop,
533            stats_set: Default::default(),
534            ..self.clone()
535        }
536    }
537
538    pub(crate) fn dtype(&self) -> &DType {
539        &self.dtype
540    }
541
542    pub(crate) fn slice_start(&self) -> usize {
543        self.slice_start
544    }
545
546    pub(crate) fn slice_stop(&self) -> usize {
547        self.slice_stop
548    }
549
550    pub(crate) fn unsliced_n_rows(&self) -> usize {
551        self.unsliced_n_rows
552    }
553}
554
impl ValiditySliceHelper for ZstdArray {
    /// Expose the unsliced validity together with the current slice bounds so
    /// the shared validity vtable can derive this array's validity lazily.
    fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
        (&self.unsliced_validity, self.slice_start, self.slice_stop)
    }
}
560
impl ArrayVTable<ZstdVTable> for ZstdVTable {
    fn len(array: &ZstdArray) -> usize {
        // Length is the logical slice, not the number of stored values.
        array.slice_stop - array.slice_start
    }

    fn dtype(array: &ZstdArray) -> &DType {
        &array.dtype
    }

    fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
        array.stats_set.to_ref(array.as_ref())
    }
}
574
impl CanonicalVTable<ZstdVTable> for ZstdVTable {
    /// Canonicalizing fully decompresses the (sliced) array.
    fn canonicalize(array: &ZstdArray) -> Canonical {
        array.decompress().to_canonical()
    }
}
580
impl OperationsVTable<ZstdVTable> for ZstdVTable {
    fn slice(array: &ZstdArray, range: Range<usize>) -> ArrayRef {
        // Slicing is lazy: only the stored slice bounds are narrowed.
        array._slice(range.start, range.end).into_array()
    }

    fn scalar_at(array: &ZstdArray, index: usize) -> Scalar {
        // Narrow to a single row so only the frames covering it are
        // decompressed.
        array._slice(index, index + 1).decompress().scalar_at(0)
    }
}