vortex_zstd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::ops::Range;
6use std::sync::Arc;
7
8use itertools::Itertools as _;
9use vortex_array::accessor::ArrayAccessor;
10use vortex_array::arrays::{BinaryView, ConstantArray, PrimitiveArray, VarBinViewArray};
11use vortex_array::compute::filter;
12use vortex_array::stats::{ArrayStats, StatsSetRef};
13use vortex_array::validity::Validity;
14use vortex_array::vtable::{
15    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
16    ValiditySliceHelper, ValidityVTableFromValiditySliceHelper,
17};
18use vortex_array::{ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
19use vortex_buffer::{Alignment, Buffer, BufferMut, ByteBuffer, ByteBufferMut};
20use vortex_dtype::DType;
21use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_err, vortex_panic};
22use vortex_mask::AllOr;
23use vortex_scalar::Scalar;
24
25use crate::serde::{ZstdFrameMetadata, ZstdMetadata};
26
// Zstd doesn't support training dictionaries on very few samples.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
// Width type of the little-endian length prefix written before each string/binary
// value in the serialized stream (see `collect_valid_vbv` / `reconstruct_views`).
type ViewLen = u32;
30
31// Overall approach here:
32// Zstd can be used on the whole array (values_per_frame = 0), resulting in a single Zstd
33// frame, or it can be used with a dictionary (values_per_frame < # values), resulting in
34// multiple Zstd frames sharing a common dictionary. This latter case is helpful if you
35// want somewhat faster access to slices or individual rows, allowing us to only
36// decompress the necessary frames.
37
38// Visually, during decompression, we have an interval of frames we're
39// decompressing and a tighter interval of the slice we actually care about.
40// |=============values (all valid elements)==============|
41// |<-skipped_uncompressed->|----decompressed-------------|
42//                              |------slice-------|
43//                              ^                  ^
44// |<-slice_uncompressed_start->|                  |
45// |<------------slice_uncompressed_stop---------->|
46// We then insert these values to the correct position using a primitive array
47// constructor.
48
49vtable!(Zstd);
50
impl VTable for ZstdVTable {
    type Array = ZstdArray;
    type Encoding = ZstdEncoding;

    type ArrayVTable = Self;
    type CanonicalVTable = Self;
    type OperationsVTable = Self;
    // Validity is derived from the unsliced validity plus the slice bounds via
    // the `ValiditySliceHelper` impl further down.
    type ValidityVTable = ValidityVTableFromValiditySliceHelper;
    type VisitorVTable = Self;
    // No specialized compute kernels or pipeline support for this encoding.
    type ComputeVTable = NotSupported;
    type EncodeVTable = Self;
    type SerdeVTable = Self;
    type PipelineVTable = NotSupported;

    fn id(_encoding: &Self::Encoding) -> EncodingId {
        EncodingId::new_ref("vortex.zstd")
    }

    fn encoding(_array: &Self::Array) -> EncodingRef {
        EncodingRef::new_ref(ZstdEncoding.as_ref())
    }
}
73
/// Marker type for the Zstd encoding (encoding id: `vortex.zstd`).
#[derive(Clone, Debug)]
pub struct ZstdEncoding;
76
/// An array whose valid values are stored as one or more Zstd-compressed
/// frames, optionally sharing a trained Zstd dictionary.
#[derive(Clone, Debug)]
pub struct ZstdArray {
    /// Shared Zstd dictionary; `None` when no dictionary was trained.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// The compressed frames, in value order.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Per-frame uncompressed sizes and value counts (plus dictionary size).
    pub(crate) metadata: ZstdMetadata,
    dtype: DType,
    /// Validity of the full, unsliced array.
    pub(crate) unsliced_validity: Validity,
    /// Logical row count of the full, unsliced array.
    unsliced_n_rows: usize,
    stats_set: ArrayStats,
    // Slicing is lazy: only these bounds (into the unsliced rows) change;
    // the compressed frames are shared untouched.
    slice_start: usize,
    slice_stop: usize,
}
89
/// Output of `ZstdArray::compress_values`: the optional trained dictionary,
/// the compressed frames, and their per-frame metadata (parallel to `frames`).
struct Frames {
    dictionary: Option<ByteBuffer>,
    frames: Vec<ByteBuffer>,
    frame_metas: Vec<ZstdFrameMetadata>,
}
95
/// Picks the maximum dictionary size to train for `uncompressed_size` bytes of
/// input: roughly 1/100th of the data, bounded to `[256 B, 100 KiB]`.
///
/// Follows the recommendation in
/// https://github.com/facebook/zstd/blob/v1.5.5/lib/zdict.h#L190 — and zstd
/// appears unable to train dictionaries smaller than 256 bytes, hence the
/// lower bound.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const MIN_DICT_SIZE: usize = 256;
    const MAX_DICT_SIZE: usize = 100 * 1024;
    (uncompressed_size / 100).clamp(MIN_DICT_SIZE, MAX_DICT_SIZE)
}
103
104fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
105    let mask = parray.validity_mask();
106    Ok(filter(&parray.to_array(), &mask)?.to_primitive())
107}
108
/// Serializes all valid (non-null) values of `vbv` into one contiguous buffer,
/// each value prefixed by its byte length as a little-endian `ViewLen`.
///
/// Also returns, for each serialized value, the byte offset in the buffer
/// where its length prefix starts. `reconstruct_views` is the inverse of this
/// layout.
fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
    let mask = vbv.validity_mask();
    let buffer_and_value_byte_indices = match mask.boolean_buffer() {
        // All elements are null: nothing to serialize.
        AllOr::None => (Buffer::empty(), Vec::new()),
        _ => {
            // Capacity: every payload byte plus one length prefix per valid value.
            let mut buffer = BufferMut::with_capacity(
                usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
                    + mask.true_count() * size_of::<ViewLen>(),
            );
            let mut value_byte_indices = Vec::new();
            vbv.with_iterator(|iterator| {
                // by flattening, we should omit nulls
                for value in iterator.flatten() {
                    value_byte_indices.push(buffer.len());
                    // here's where we write the string lengths
                    buffer
                        .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
                    buffer.extend_from_slice(value);
                }
                Ok::<_, VortexError>(())
            })??;
            (buffer.freeze(), value_byte_indices)
        }
    };
    Ok(buffer_and_value_byte_indices)
}
135
/// Rebuilds `BinaryView`s from the length-prefixed layout written by
/// `collect_valid_vbv`: walks `buffer`, reading a little-endian `ViewLen`
/// length followed by that many payload bytes, and emits one view per value.
/// All views reference buffer index 0 with offsets pointing at the payload
/// (just past the length prefix).
fn reconstruct_views(buffer: ByteBuffer) -> Buffer<BinaryView> {
    let mut res = BufferMut::<BinaryView>::empty();
    let mut offset = 0;
    while offset < buffer.len() {
        // Read the length prefix for the next value.
        let str_len = ViewLen::from_le_bytes(
            buffer
                .get(offset..offset + size_of::<ViewLen>())
                .vortex_expect("corrupted zstd length")
                .try_into()
                .vortex_expect("must fit ViewLen size"),
        ) as usize;
        offset += size_of::<ViewLen>();
        let value = &buffer[offset..offset + str_len];
        res.push(BinaryView::make_view(
            value,
            0,
            u32::try_from(offset).vortex_expect("offset must fit in u32"),
        ));
        offset += str_len;
    }
    res.freeze()
}
158
159impl ZstdArray {
160    pub fn new(
161        dictionary: Option<ByteBuffer>,
162        frames: Vec<ByteBuffer>,
163        dtype: DType,
164        metadata: ZstdMetadata,
165        n_rows: usize,
166        validity: Validity,
167    ) -> Self {
168        Self {
169            dictionary,
170            frames,
171            metadata,
172            dtype,
173            unsliced_validity: validity,
174            unsliced_n_rows: n_rows,
175            stats_set: Default::default(),
176            slice_start: 0,
177            slice_stop: n_rows,
178        }
179    }
180
181    fn compress_values(
182        value_bytes: &ByteBuffer,
183        frame_byte_starts: &[usize],
184        level: i32,
185        values_per_frame: usize,
186        n_values: usize,
187    ) -> VortexResult<Frames> {
188        let n_frames = frame_byte_starts.len();
189
190        // Would-be sample sizes if we end up applying zstd dictionary
191        let mut sample_sizes = Vec::with_capacity(n_frames);
192        for i in 0..n_frames {
193            let frame_byte_end = frame_byte_starts
194                .get(i + 1)
195                .copied()
196                .unwrap_or(value_bytes.len());
197            sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
198        }
199        debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
200
201        let (dictionary, mut compressor) = if sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY {
202            // no dictionary
203            (None, zstd::bulk::Compressor::new(level)?)
204        } else {
205            // with dictionary
206            let max_dict_size = choose_max_dict_size(value_bytes.len());
207            let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
208                .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
209
210            let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
211            (Some(ByteBuffer::from(dict)), compressor)
212        };
213
214        let mut frame_metas = vec![];
215        let mut frames = vec![];
216        for i in 0..n_frames {
217            let frame_byte_end = frame_byte_starts
218                .get(i + 1)
219                .copied()
220                .unwrap_or(value_bytes.len());
221            let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
222            let compressed = compressor
223                .compress(uncompressed)
224                .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
225            frame_metas.push(ZstdFrameMetadata {
226                uncompressed_size: uncompressed.len() as u64,
227                n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
228            });
229            frames.push(ByteBuffer::from(compressed));
230        }
231
232        Ok(Frames {
233            dictionary,
234            frames,
235            frame_metas,
236        })
237    }
238
239    pub fn from_primitive(
240        parray: &PrimitiveArray,
241        level: i32,
242        values_per_frame: usize,
243    ) -> VortexResult<Self> {
244        let dtype = parray.dtype().clone();
245        let byte_width = parray.ptype().byte_width();
246
247        // We compress only the valid elements.
248        let values = collect_valid_primitive(parray)?;
249        let n_values = values.len();
250        let values_per_frame = if values_per_frame > 0 {
251            values_per_frame
252        } else {
253            n_values
254        };
255
256        let value_bytes = values.byte_buffer();
257        let frame_byte_starts = (0..n_values * byte_width)
258            .step_by(values_per_frame * byte_width)
259            .collect::<Vec<_>>();
260        let Frames {
261            dictionary,
262            frames,
263            frame_metas,
264        } = Self::compress_values(
265            value_bytes,
266            &frame_byte_starts,
267            level,
268            values_per_frame,
269            n_values,
270        )?;
271
272        let metadata = ZstdMetadata {
273            dictionary_size: dictionary
274                .as_ref()
275                .map_or(0, |dict| dict.len())
276                .try_into()?,
277            frames: frame_metas,
278        };
279
280        Ok(ZstdArray::new(
281            dictionary,
282            frames,
283            dtype,
284            metadata,
285            parray.len(),
286            parray.validity().clone(),
287        ))
288    }
289
290    pub fn from_var_bin_view(
291        vbv: &VarBinViewArray,
292        level: i32,
293        values_per_frame: usize,
294    ) -> VortexResult<Self> {
295        // Approach for strings: we prefix each string with its length as a u32.
296        // This is the same as what Parquet does. In some cases it may be better
297        // to separate the binary data and lengths as two separate streams, but
298        // this approach is simpler and can be best in cases when there is
299        // mutual information between strings and their lengths.
300        let dtype = vbv.dtype().clone();
301
302        // We compress only the valid elements.
303        let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
304        let n_values = value_byte_indices.len();
305        let values_per_frame = if values_per_frame > 0 {
306            values_per_frame
307        } else {
308            n_values
309        };
310
311        let frame_byte_starts = (0..n_values)
312            .step_by(values_per_frame)
313            .map(|i| value_byte_indices[i])
314            .collect::<Vec<_>>();
315        let Frames {
316            dictionary,
317            frames,
318            frame_metas,
319        } = Self::compress_values(
320            &value_bytes,
321            &frame_byte_starts,
322            level,
323            values_per_frame,
324            n_values,
325        )?;
326
327        let metadata = ZstdMetadata {
328            dictionary_size: dictionary
329                .as_ref()
330                .map_or(0, |dict| dict.len())
331                .try_into()?,
332            frames: frame_metas,
333        };
334        Ok(ZstdArray::new(
335            dictionary,
336            frames,
337            dtype,
338            metadata,
339            vbv.len(),
340            vbv.validity().clone(),
341        ))
342    }
343
344    pub fn from_canonical(
345        canonical: &Canonical,
346        level: i32,
347        values_per_frame: usize,
348    ) -> VortexResult<Option<Self>> {
349        match canonical {
350            Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
351                parray,
352                level,
353                values_per_frame,
354            )?)),
355            Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
356                vbv,
357                level,
358                values_per_frame,
359            )?)),
360            _ => Ok(None),
361        }
362    }
363
364    pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
365        Self::from_canonical(&array.to_canonical(), level, values_per_frame)?
366            .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
367    }
368
369    fn byte_width(&self) -> usize {
370        if self.dtype.is_primitive() {
371            self.dtype.as_ptype().byte_width()
372        } else {
373            1
374        }
375    }
376
377    pub fn decompress(&self) -> ArrayRef {
378        // To start, we figure out which frames we need to decompress, and with
379        // what row offset into the first such frame.
380        let byte_width = self.byte_width();
381        let slice_n_rows = self.slice_stop - self.slice_start;
382        let slice_value_indices = self
383            .unsliced_validity
384            .to_mask(self.unsliced_n_rows)
385            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
386
387        let slice_value_idx_start = slice_value_indices[0];
388        let slice_value_idx_stop = slice_value_indices[1];
389
390        let mut frames_to_decompress = vec![];
391        let mut value_idx_start = 0;
392        let mut uncompressed_size_to_decompress = 0;
393        let mut n_skipped_values = 0;
394        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
395            if value_idx_start >= slice_value_idx_stop {
396                break;
397            }
398
399            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
400                .vortex_expect("Uncompressed size must fit in usize");
401            let frame_n_values = if frame_meta.n_values == 0 {
402                // possibly older primitive-only metadata that just didn't store this
403                frame_uncompressed_size / byte_width
404            } else {
405                usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
406            };
407
408            let value_idx_stop = value_idx_start + frame_n_values;
409            if value_idx_stop > slice_value_idx_start {
410                // we need this frame
411                frames_to_decompress.push(frame);
412                uncompressed_size_to_decompress += frame_uncompressed_size;
413            } else {
414                n_skipped_values += frame_n_values;
415            }
416            value_idx_start = value_idx_stop;
417        }
418
419        // then we actually decompress those frames
420        let mut decompressor = if let Some(dictionary) = &self.dictionary {
421            zstd::bulk::Decompressor::with_dictionary(dictionary)
422        } else {
423            zstd::bulk::Decompressor::new()
424        }
425        .vortex_expect("Decompressor encountered io error");
426        let mut decompressed = ByteBufferMut::with_capacity_aligned(
427            uncompressed_size_to_decompress,
428            Alignment::new(byte_width),
429        );
430        unsafe {
431            // safety: we immediately fill all bytes in the following loop,
432            // assuming our metadata's uncompressed size is correct
433            decompressed.set_len(uncompressed_size_to_decompress);
434        }
435        let mut uncompressed_start = 0;
436        for frame in frames_to_decompress {
437            let uncompressed_written = decompressor
438                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])
439                .vortex_expect("error while decompressing zstd array");
440            uncompressed_start += uncompressed_written;
441        }
442        if uncompressed_start != uncompressed_size_to_decompress {
443            vortex_panic!(
444                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
445                uncompressed_size_to_decompress,
446                uncompressed_start
447            );
448        }
449
450        let decompressed = decompressed.freeze();
451        // Last, we slice the exact values requested out of the decompressed data.
452        let slice_validity = self
453            .unsliced_validity
454            .slice(self.slice_start..self.slice_stop);
455
456        match &self.dtype {
457            DType::Primitive(..) => {
458                let slice_values_buffer = decompressed.slice(
459                    (slice_value_idx_start - n_skipped_values) * byte_width
460                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
461                );
462                let primitive = PrimitiveArray::from_values_byte_buffer(
463                    slice_values_buffer,
464                    self.dtype.as_ptype(),
465                    slice_validity,
466                    slice_n_rows,
467                );
468
469                primitive.into_array()
470            }
471            DType::Binary(_) | DType::Utf8(_) => {
472                match slice_validity.to_mask(slice_n_rows).indices() {
473                    AllOr::All => {
474                        // the decompressed buffer is a bunch of interleaved u32 lengths
475                        // and strings of those lengths, we need to reconstruct the
476                        // views into those strings by passing through the buffer.
477                        let valid_views = reconstruct_views(decompressed.clone()).slice(
478                            slice_value_idx_start - n_skipped_values
479                                ..slice_value_idx_stop - n_skipped_values,
480                        );
481
482                        // SAFETY: we properly construct the views inside `reconstruct_views`
483                        unsafe {
484                            VarBinViewArray::new_unchecked(
485                                valid_views,
486                                Arc::from([decompressed]),
487                                self.dtype.clone(),
488                                slice_validity,
489                            )
490                        }
491                        .into_array()
492                    }
493                    AllOr::None => {
494                        ConstantArray::new(Scalar::null(self.dtype.clone()), slice_n_rows)
495                            .into_array()
496                    }
497                    AllOr::Some(valid_indices) => {
498                        // the decompressed buffer is a bunch of interleaved u32 lengths
499                        // and strings of those lengths, we need to reconstruct the
500                        // views into those strings by passing through the buffer.
501                        let valid_views = reconstruct_views(decompressed.clone()).slice(
502                            slice_value_idx_start - n_skipped_values
503                                ..slice_value_idx_stop - n_skipped_values,
504                        );
505
506                        let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
507                        for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
508                            views[*index] = view
509                        }
510
511                        // SAFETY: we properly construct the views inside `reconstruct_views`
512                        unsafe {
513                            VarBinViewArray::new_unchecked(
514                                views.freeze(),
515                                Arc::from([decompressed]),
516                                self.dtype.clone(),
517                                slice_validity,
518                            )
519                        }
520                        .into_array()
521                    }
522                }
523            }
524            _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype),
525        }
526    }
527
528    pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
529        ZstdArray {
530            slice_start: self.slice_start + start,
531            slice_stop: self.slice_start + stop,
532            stats_set: Default::default(),
533            ..self.clone()
534        }
535    }
536
537    pub(crate) fn dtype(&self) -> &DType {
538        &self.dtype
539    }
540
541    pub(crate) fn slice_start(&self) -> usize {
542        self.slice_start
543    }
544
545    pub(crate) fn slice_stop(&self) -> usize {
546        self.slice_stop
547    }
548
549    pub(crate) fn unsliced_n_rows(&self) -> usize {
550        self.unsliced_n_rows
551    }
552}
553
// Exposes the raw (unsliced) validity together with the current slice bounds,
// letting `ValidityVTableFromValiditySliceHelper` derive this array's validity.
impl ValiditySliceHelper for ZstdArray {
    fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
        (&self.unsliced_validity, self.slice_start, self.slice_stop)
    }
}
559
impl ArrayVTable<ZstdVTable> for ZstdVTable {
    /// Logical length is the extent of the current (lazy) slice.
    fn len(array: &ZstdArray) -> usize {
        array.slice_stop - array.slice_start
    }

    fn dtype(array: &ZstdArray) -> &DType {
        &array.dtype
    }

    fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
        array.stats_set.to_ref(array.as_ref())
    }
}
573
impl CanonicalVTable<ZstdVTable> for ZstdVTable {
    /// Canonicalizing a Zstd array means fully decompressing the sliced range.
    fn canonicalize(array: &ZstdArray) -> Canonical {
        array.decompress().to_canonical()
    }
}
579
impl OperationsVTable<ZstdVTable> for ZstdVTable {
    /// Slicing is lazy: only the recorded bounds change; nothing is
    /// decompressed here.
    fn slice(array: &ZstdArray, range: Range<usize>) -> ArrayRef {
        array._slice(range.start, range.end).into_array()
    }

    /// Takes a one-row lazy slice, decompresses the frames covering it, and
    /// reads element 0 of the result.
    fn scalar_at(array: &ZstdArray, index: usize) -> Scalar {
        array._slice(index, index + 1).decompress().scalar_at(0)
    }
}