vortex_zstd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::ops::Range;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_array::accessor::ArrayAccessor;
11use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinViewArray};
12use vortex_array::compute::filter;
13use vortex_array::stats::{ArrayStats, StatsSetRef};
14use vortex_array::validity::Validity;
15use vortex_array::vtable::{
16    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
17    ValiditySliceHelper, ValidityVTableFromValiditySliceHelper,
18};
19use vortex_array::{
20    ArrayEq, ArrayHash, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, Precision,
21    ToCanonical, vtable,
22};
23use vortex_buffer::{Alignment, Buffer, BufferMut, ByteBuffer, ByteBufferMut};
24use vortex_dtype::DType;
25use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_err, vortex_panic};
26use vortex_mask::AllOr;
27use vortex_scalar::Scalar;
28use vortex_vector::binaryview::BinaryView;
29
30use crate::serde::{ZstdFrameMetadata, ZstdMetadata};
31
// Zstd doesn't support training dictionaries on very few samples.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
// Width of the little-endian length prefix written before each string value in
// the compressed stream (see `collect_valid_vbv` / `reconstruct_views`).
type ViewLen = u32;
35
36// Overall approach here:
37// Zstd can be used on the whole array (values_per_frame = 0), resulting in a single Zstd
38// frame, or it can be used with a dictionary (values_per_frame < # values), resulting in
39// multiple Zstd frames sharing a common dictionary. This latter case is helpful if you
40// want somewhat faster access to slices or individual rows, allowing us to only
41// decompress the necessary frames.
42
43// Visually, during decompression, we have an interval of frames we're
44// decompressing and a tighter interval of the slice we actually care about.
45// |=============values (all valid elements)==============|
46// |<-skipped_uncompressed->|----decompressed-------------|
47//                              |------slice-------|
48//                              ^                  ^
49// |<-slice_uncompressed_start->|                  |
50// |<------------slice_uncompressed_stop---------->|
// We then insert these values into their correct positions using a primitive
// array constructor.
53
// Declares the `ZstdVTable` wrapper type; its `VTable` impl follows below.
vtable!(Zstd);
55
/// VTable wiring for the Zstd encoding: selects which capability tables are
/// implemented here (`Self`) and which are unsupported (`NotSupported`).
impl VTable for ZstdVTable {
    type Array = ZstdArray;
    type Encoding = ZstdEncoding;

    type ArrayVTable = Self;
    type CanonicalVTable = Self;
    type OperationsVTable = Self;
    // Validity is derived from the unsliced validity plus slice bounds via the
    // `ValiditySliceHelper` impl below.
    type ValidityVTable = ValidityVTableFromValiditySliceHelper;
    type VisitorVTable = Self;
    type ComputeVTable = NotSupported;
    type EncodeVTable = Self;
    type SerdeVTable = Self;
    type OperatorVTable = NotSupported;

    /// Stable identifier under which this encoding is registered.
    fn id(_encoding: &Self::Encoding) -> EncodingId {
        EncodingId::new_ref("vortex.zstd")
    }

    fn encoding(_array: &Self::Array) -> EncodingRef {
        EncodingRef::new_ref(ZstdEncoding.as_ref())
    }
}
78
/// Marker type for the Zstd encoding (`vortex.zstd`).
#[derive(Clone, Debug)]
pub struct ZstdEncoding;
81
/// An array whose valid values are stored Zstd-compressed in one or more
/// frames, optionally sharing a trained dictionary.
#[derive(Clone, Debug)]
pub struct ZstdArray {
    // Shared Zstd dictionary; present only when enough frames were available
    // to train one (see MIN_SAMPLES_FOR_DICTIONARY).
    pub(crate) dictionary: Option<ByteBuffer>,
    // Compressed frames; concatenated, their uncompressed forms cover all
    // valid (non-null) values of the unsliced array.
    pub(crate) frames: Vec<ByteBuffer>,
    // Per-frame uncompressed sizes and value counts, plus dictionary size.
    pub(crate) metadata: ZstdMetadata,
    dtype: DType,
    // Validity of the full (unsliced) array; nulls are not stored in frames.
    pub(crate) unsliced_validity: Validity,
    // Row count of the full (unsliced) array.
    unsliced_n_rows: usize,
    stats_set: ArrayStats,
    // Slicing is lazy: only these bounds change, the compressed frames are
    // shared untouched (see `_slice`).
    slice_start: usize,
    slice_stop: usize,
}
94
/// Result of compressing a value buffer: the compressed frames alongside the
/// optional trained dictionary and per-frame metadata.
struct Frames {
    dictionary: Option<ByteBuffer>,
    frames: Vec<ByteBuffer>,
    frame_metas: Vec<ZstdFrameMetadata>,
}
100
/// Picks the dictionary-size budget for zstd training: 1/100th of the data,
/// bounded to [256 B, 100 KiB]. Follows the recommendations in
/// https://github.com/facebook/zstd/blob/v1.5.5/lib/zdict.h#L190 ;
/// zstd also can't train dictionaries with fewer than 256 bytes.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const MIN_DICT_SIZE: usize = 256;
    const MAX_DICT_SIZE: usize = 100 * 1024;
    (uncompressed_size / 100).max(MIN_DICT_SIZE).min(MAX_DICT_SIZE)
}
108
109fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
110    let mask = parray.validity_mask();
111    Ok(filter(&parray.to_array(), &mask)?.to_primitive())
112}
113
/// Serializes the valid elements of `vbv` into one contiguous buffer, each
/// value prefixed with its byte length as a little-endian `ViewLen` (u32).
/// Also returns, for every valid value, the byte offset of its length prefix
/// within the returned buffer.
fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
    let mask = vbv.validity_mask();
    let buffer_and_value_byte_indices = match mask.bit_buffer() {
        // Every element is null: nothing to serialize.
        AllOr::None => (Buffer::empty(), Vec::new()),
        _ => {
            // Capacity estimate: the array's raw bytes plus one length prefix
            // per valid value. NOTE(review): nbytes presumably bounds the
            // payload size, so this may over-reserve slightly — harmless.
            let mut buffer = BufferMut::with_capacity(
                usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
                    + mask.true_count() * size_of::<ViewLen>(),
            );
            let mut value_byte_indices = Vec::new();
            vbv.with_iterator(|iterator| {
                // by flattening, we should omit nulls
                for value in iterator.flatten() {
                    // Record where this value's record (prefix + bytes) starts.
                    value_byte_indices.push(buffer.len());
                    // here's where we write the string lengths
                    buffer
                        .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
                    buffer.extend_from_slice(value);
                }
                Ok::<_, VortexError>(())
            })?;
            (buffer.freeze(), value_byte_indices)
        }
    };
    Ok(buffer_and_value_byte_indices)
}
140
141fn reconstruct_views(buffer: ByteBuffer) -> Buffer<BinaryView> {
142    let mut res = BufferMut::<BinaryView>::empty();
143    let mut offset = 0;
144    while offset < buffer.len() {
145        let str_len = ViewLen::from_le_bytes(
146            buffer
147                .get(offset..offset + size_of::<ViewLen>())
148                .vortex_expect("corrupted zstd length")
149                .try_into()
150                .vortex_expect("must fit ViewLen size"),
151        ) as usize;
152        offset += size_of::<ViewLen>();
153        let value = &buffer[offset..offset + str_len];
154        res.push(BinaryView::make_view(
155            value,
156            0,
157            u32::try_from(offset).vortex_expect("offset must fit in u32"),
158        ));
159        offset += str_len;
160    }
161    res.freeze()
162}
163
164impl ZstdArray {
165    pub fn new(
166        dictionary: Option<ByteBuffer>,
167        frames: Vec<ByteBuffer>,
168        dtype: DType,
169        metadata: ZstdMetadata,
170        n_rows: usize,
171        validity: Validity,
172    ) -> Self {
173        Self {
174            dictionary,
175            frames,
176            metadata,
177            dtype,
178            unsliced_validity: validity,
179            unsliced_n_rows: n_rows,
180            stats_set: Default::default(),
181            slice_start: 0,
182            slice_stop: n_rows,
183        }
184    }
185
    /// Compresses `value_bytes` into one Zstd frame per entry of
    /// `frame_byte_starts` (each entry is the byte offset where a frame's
    /// uncompressed data begins). When at least `MIN_SAMPLES_FOR_DICTIONARY`
    /// frames exist, a shared dictionary is trained first.
    ///
    /// `values_per_frame` and `n_values` are only used to record each frame's
    /// logical value count in the returned metadata.
    fn compress_values(
        value_bytes: &ByteBuffer,
        frame_byte_starts: &[usize],
        level: i32,
        values_per_frame: usize,
        n_values: usize,
    ) -> VortexResult<Frames> {
        let n_frames = frame_byte_starts.len();

        // Would-be sample sizes if we end up applying zstd dictionary
        let mut sample_sizes = Vec::with_capacity(n_frames);
        for i in 0..n_frames {
            // The last frame ends at the end of the buffer.
            let frame_byte_end = frame_byte_starts
                .get(i + 1)
                .copied()
                .unwrap_or(value_bytes.len());
            sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
        }
        debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());

        let (dictionary, mut compressor) = if sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY {
            // no dictionary
            (None, zstd::bulk::Compressor::new(level)?)
        } else {
            // with dictionary
            let max_dict_size = choose_max_dict_size(value_bytes.len());
            let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
                .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;

            let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
            (Some(ByteBuffer::from(dict)), compressor)
        };

        let mut frame_metas = vec![];
        let mut frames = vec![];
        for i in 0..n_frames {
            let frame_byte_end = frame_byte_starts
                .get(i + 1)
                .copied()
                .unwrap_or(value_bytes.len());
            let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
            let compressed = compressor
                .compress(uncompressed)
                .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
            frame_metas.push(ZstdFrameMetadata {
                uncompressed_size: uncompressed.len() as u64,
                // The final frame may carry fewer than values_per_frame values.
                n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
            });
            frames.push(ByteBuffer::from(compressed));
        }

        Ok(Frames {
            dictionary,
            frames,
            frame_metas,
        })
    }
243
244    pub fn from_primitive(
245        parray: &PrimitiveArray,
246        level: i32,
247        values_per_frame: usize,
248    ) -> VortexResult<Self> {
249        let dtype = parray.dtype().clone();
250        let byte_width = parray.ptype().byte_width();
251
252        // We compress only the valid elements.
253        let values = collect_valid_primitive(parray)?;
254        let n_values = values.len();
255        let values_per_frame = if values_per_frame > 0 {
256            values_per_frame
257        } else {
258            n_values
259        };
260
261        let value_bytes = values.byte_buffer();
262        let frame_byte_starts = (0..n_values * byte_width)
263            .step_by(values_per_frame * byte_width)
264            .collect::<Vec<_>>();
265        let Frames {
266            dictionary,
267            frames,
268            frame_metas,
269        } = Self::compress_values(
270            value_bytes,
271            &frame_byte_starts,
272            level,
273            values_per_frame,
274            n_values,
275        )?;
276
277        let metadata = ZstdMetadata {
278            dictionary_size: dictionary
279                .as_ref()
280                .map_or(0, |dict| dict.len())
281                .try_into()?,
282            frames: frame_metas,
283        };
284
285        Ok(ZstdArray::new(
286            dictionary,
287            frames,
288            dtype,
289            metadata,
290            parray.len(),
291            parray.validity().clone(),
292        ))
293    }
294
295    pub fn from_var_bin_view(
296        vbv: &VarBinViewArray,
297        level: i32,
298        values_per_frame: usize,
299    ) -> VortexResult<Self> {
300        // Approach for strings: we prefix each string with its length as a u32.
301        // This is the same as what Parquet does. In some cases it may be better
302        // to separate the binary data and lengths as two separate streams, but
303        // this approach is simpler and can be best in cases when there is
304        // mutual information between strings and their lengths.
305        let dtype = vbv.dtype().clone();
306
307        // We compress only the valid elements.
308        let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
309        let n_values = value_byte_indices.len();
310        let values_per_frame = if values_per_frame > 0 {
311            values_per_frame
312        } else {
313            n_values
314        };
315
316        let frame_byte_starts = (0..n_values)
317            .step_by(values_per_frame)
318            .map(|i| value_byte_indices[i])
319            .collect::<Vec<_>>();
320        let Frames {
321            dictionary,
322            frames,
323            frame_metas,
324        } = Self::compress_values(
325            &value_bytes,
326            &frame_byte_starts,
327            level,
328            values_per_frame,
329            n_values,
330        )?;
331
332        let metadata = ZstdMetadata {
333            dictionary_size: dictionary
334                .as_ref()
335                .map_or(0, |dict| dict.len())
336                .try_into()?,
337            frames: frame_metas,
338        };
339        Ok(ZstdArray::new(
340            dictionary,
341            frames,
342            dtype,
343            metadata,
344            vbv.len(),
345            vbv.validity().clone(),
346        ))
347    }
348
349    pub fn from_canonical(
350        canonical: &Canonical,
351        level: i32,
352        values_per_frame: usize,
353    ) -> VortexResult<Option<Self>> {
354        match canonical {
355            Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
356                parray,
357                level,
358                values_per_frame,
359            )?)),
360            Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
361                vbv,
362                level,
363                values_per_frame,
364            )?)),
365            _ => Ok(None),
366        }
367    }
368
369    pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
370        Self::from_canonical(&array.to_canonical(), level, values_per_frame)?
371            .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
372    }
373
374    fn byte_width(&self) -> usize {
375        if self.dtype.is_primitive() {
376            self.dtype.as_ptype().byte_width()
377        } else {
378            1
379        }
380    }
381
    /// Decompresses only the frames overlapping the current slice and returns
    /// an array (Primitive or VarBinView) of the slice's length.
    ///
    /// Frames store valid values only, so slice row positions are first
    /// translated into valid-value indices via the validity mask; nulls are
    /// reintroduced afterwards through `slice_validity`.
    pub fn decompress(&self) -> ArrayRef {
        // To start, we figure out which frames we need to decompress, and with
        // what row offset into the first such frame.
        let byte_width = self.byte_width();
        let slice_n_rows = self.slice_stop - self.slice_start;
        // Count of valid values preceding slice_start / slice_stop: these are
        // indices into the stream of compressed (valid-only) values.
        let slice_value_indices = self
            .unsliced_validity
            .to_mask(self.unsliced_n_rows)
            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);

        let slice_value_idx_start = slice_value_indices[0];
        let slice_value_idx_stop = slice_value_indices[1];

        let mut frames_to_decompress = vec![];
        let mut value_idx_start = 0;
        let mut uncompressed_size_to_decompress = 0;
        let mut n_skipped_values = 0;
        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
            if value_idx_start >= slice_value_idx_stop {
                // All remaining frames lie past the slice; stop scanning.
                break;
            }

            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
                .vortex_expect("Uncompressed size must fit in usize");
            let frame_n_values = if frame_meta.n_values == 0 {
                // possibly older primitive-only metadata that just didn't store this
                frame_uncompressed_size / byte_width
            } else {
                usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
            };

            let value_idx_stop = value_idx_start + frame_n_values;
            if value_idx_stop > slice_value_idx_start {
                // we need this frame
                frames_to_decompress.push(frame);
                uncompressed_size_to_decompress += frame_uncompressed_size;
            } else {
                // Frame lies entirely before the slice: skip it, but remember
                // how many values it held so indices can be rebased below.
                n_skipped_values += frame_n_values;
            }
            value_idx_start = value_idx_stop;
        }

        // then we actually decompress those frames
        let mut decompressor = if let Some(dictionary) = &self.dictionary {
            zstd::bulk::Decompressor::with_dictionary(dictionary)
        } else {
            zstd::bulk::Decompressor::new()
        }
        .vortex_expect("Decompressor encountered io error");
        let mut decompressed = ByteBufferMut::with_capacity_aligned(
            uncompressed_size_to_decompress,
            Alignment::new(byte_width),
        );
        unsafe {
            // safety: we immediately fill all bytes in the following loop,
            // assuming our metadata's uncompressed size is correct
            decompressed.set_len(uncompressed_size_to_decompress);
        }
        // Each frame decompresses directly into its region of the output.
        let mut uncompressed_start = 0;
        for frame in frames_to_decompress {
            let uncompressed_written = decompressor
                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])
                .vortex_expect("error while decompressing zstd array");
            uncompressed_start += uncompressed_written;
        }
        if uncompressed_start != uncompressed_size_to_decompress {
            vortex_panic!(
                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
                uncompressed_size_to_decompress,
                uncompressed_start
            );
        }

        let decompressed = decompressed.freeze();
        // Last, we slice the exact values requested out of the decompressed data.
        let slice_validity = self
            .unsliced_validity
            .slice(self.slice_start..self.slice_stop);

        match &self.dtype {
            DType::Primitive(..) => {
                // Indices are rebased by n_skipped_values because the
                // decompressed buffer starts at the first decompressed frame,
                // not at the start of the whole value stream.
                let slice_values_buffer = decompressed.slice(
                    (slice_value_idx_start - n_skipped_values) * byte_width
                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
                );
                let primitive = PrimitiveArray::from_values_byte_buffer(
                    slice_values_buffer,
                    self.dtype.as_ptype(),
                    slice_validity,
                    slice_n_rows,
                );

                primitive.into_array()
            }
            DType::Binary(_) | DType::Utf8(_) => {
                match slice_validity.to_mask(slice_n_rows).indices() {
                    AllOr::All => {
                        // the decompressed buffer is a bunch of interleaved u32 lengths
                        // and strings of those lengths, we need to reconstruct the
                        // views into those strings by passing through the buffer.
                        let valid_views = reconstruct_views(decompressed.clone()).slice(
                            slice_value_idx_start - n_skipped_values
                                ..slice_value_idx_stop - n_skipped_values,
                        );

                        // SAFETY: we properly construct the views inside `reconstruct_views`
                        unsafe {
                            VarBinViewArray::new_unchecked(
                                valid_views,
                                Arc::from([decompressed]),
                                self.dtype.clone(),
                                slice_validity,
                            )
                        }
                        .into_array()
                    }
                    AllOr::None => {
                        // Every row in the slice is null.
                        ConstantArray::new(Scalar::null(self.dtype.clone()), slice_n_rows)
                            .into_array()
                    }
                    AllOr::Some(valid_indices) => {
                        // the decompressed buffer is a bunch of interleaved u32 lengths
                        // and strings of those lengths, we need to reconstruct the
                        // views into those strings by passing through the buffer.
                        let valid_views = reconstruct_views(decompressed.clone()).slice(
                            slice_value_idx_start - n_skipped_values
                                ..slice_value_idx_stop - n_skipped_values,
                        );

                        // Scatter each valid view back to its row position;
                        // null rows keep the zeroed view.
                        let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
                        for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
                            views[*index] = view
                        }

                        // SAFETY: we properly construct the views inside `reconstruct_views`
                        unsafe {
                            VarBinViewArray::new_unchecked(
                                views.freeze(),
                                Arc::from([decompressed]),
                                self.dtype.clone(),
                                slice_validity,
                            )
                        }
                        .into_array()
                    }
                }
            }
            _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype),
        }
    }
532
533    pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
534        ZstdArray {
535            slice_start: self.slice_start + start,
536            slice_stop: self.slice_start + stop,
537            stats_set: Default::default(),
538            ..self.clone()
539        }
540    }
541
    /// The logical dtype of this array.
    pub(crate) fn dtype(&self) -> &DType {
        &self.dtype
    }

    /// Start of the current slice within the unsliced row range.
    pub(crate) fn slice_start(&self) -> usize {
        self.slice_start
    }

    /// End (exclusive) of the current slice within the unsliced row range.
    pub(crate) fn slice_stop(&self) -> usize {
        self.slice_stop
    }

    /// Total row count of the underlying (unsliced) array.
    pub(crate) fn unsliced_n_rows(&self) -> usize {
        self.unsliced_n_rows
    }
557}
558
/// Exposes the unsliced validity plus slice bounds so the shared
/// `ValidityVTableFromValiditySliceHelper` can derive per-slice validity.
impl ValiditySliceHelper for ZstdArray {
    fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
        (&self.unsliced_validity, self.slice_start, self.slice_stop)
    }
}
564
565impl ArrayVTable<ZstdVTable> for ZstdVTable {
566    fn len(array: &ZstdArray) -> usize {
567        array.slice_stop - array.slice_start
568    }
569
570    fn dtype(array: &ZstdArray) -> &DType {
571        &array.dtype
572    }
573
574    fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
575        array.stats_set.to_ref(array.as_ref())
576    }
577
578    fn array_hash<H: std::hash::Hasher>(array: &ZstdArray, state: &mut H, precision: Precision) {
579        match &array.dictionary {
580            Some(dict) => {
581                true.hash(state);
582                dict.array_hash(state, precision);
583            }
584            None => {
585                false.hash(state);
586            }
587        }
588        for frame in &array.frames {
589            frame.array_hash(state, precision);
590        }
591        array.dtype.hash(state);
592        array.unsliced_validity.array_hash(state, precision);
593        array.unsliced_n_rows.hash(state);
594        array.slice_start.hash(state);
595        array.slice_stop.hash(state);
596    }
597
598    fn array_eq(array: &ZstdArray, other: &ZstdArray, precision: Precision) -> bool {
599        if !match (&array.dictionary, &other.dictionary) {
600            (Some(d1), Some(d2)) => d1.array_eq(d2, precision),
601            (None, None) => true,
602            _ => false,
603        } {
604            return false;
605        }
606        if array.frames.len() != other.frames.len() {
607            return false;
608        }
609        for (a, b) in array.frames.iter().zip(&other.frames) {
610            if !a.array_eq(b, precision) {
611                return false;
612            }
613        }
614        array.dtype == other.dtype
615            && array
616                .unsliced_validity
617                .array_eq(&other.unsliced_validity, precision)
618            && array.unsliced_n_rows == other.unsliced_n_rows
619            && array.slice_start == other.slice_start
620            && array.slice_stop == other.slice_stop
621    }
622}
623
impl CanonicalVTable<ZstdVTable> for ZstdVTable {
    /// Decompresses the sliced values and canonicalizes the result.
    fn canonicalize(array: &ZstdArray) -> Canonical {
        array.decompress().to_canonical()
    }
}
629
630impl OperationsVTable<ZstdVTable> for ZstdVTable {
631    fn slice(array: &ZstdArray, range: Range<usize>) -> ArrayRef {
632        array._slice(range.start, range.end).into_array()
633    }
634
635    fn scalar_at(array: &ZstdArray, index: usize) -> Scalar {
636        array._slice(index, index + 1).decompress().scalar_at(0)
637    }
638}