vortex_zstd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::ops::Range;
6use std::sync::Arc;
7
8use vortex_array::accessor::ArrayAccessor;
9use vortex_array::arrays::{BinaryView, PrimitiveArray, VarBinViewArray};
10use vortex_array::compute::filter;
11use vortex_array::stats::{ArrayStats, StatsSetRef};
12use vortex_array::validity::Validity;
13use vortex_array::vtable::{
14    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
15    ValiditySliceHelper, ValidityVTableFromValiditySliceHelper,
16};
17use vortex_array::{ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
18use vortex_buffer::{Alignment, Buffer, BufferMut, ByteBuffer, ByteBufferMut};
19use vortex_dtype::DType;
20use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_err, vortex_panic};
21use vortex_mask::AllOr;
22use vortex_scalar::Scalar;
23
24use crate::serde::{ZstdFrameMetadata, ZstdMetadata};
25
// Zstd doesn't support training dictionaries on very few samples.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
/// Integer type of the little-endian length prefix written before each
/// variable-length value in the uncompressed stream (see `collect_valid_vbv`).
type ViewLen = u32;
29
30// Overall approach here:
31// Zstd can be used on the whole array (values_per_frame = 0), resulting in a single Zstd
32// frame, or it can be used with a dictionary (values_per_frame < # values), resulting in
33// multiple Zstd frames sharing a common dictionary. This latter case is helpful if you
34// want somewhat faster access to slices or individual rows, allowing us to only
35// decompress the necessary frames.
36
37// Visually, during decompression, we have an interval of frames we're
38// decompressing and a tighter interval of the slice we actually care about.
39// |=============values (all valid elements)==============|
40// |<-skipped_uncompressed->|----decompressed-------------|
41//                              |------slice-------|
42//                              ^                  ^
43// |<-slice_uncompressed_start->|                  |
44// |<------------slice_uncompressed_stop---------->|
// We then insert these values at the correct positions using a primitive array
// constructor.
47
48vtable!(Zstd);
49
/// VTable wiring for the Zstd encoding. Most capabilities are implemented on
/// `ZstdVTable` itself; compute and pipeline support are explicitly opted out
/// with `NotSupported`.
impl VTable for ZstdVTable {
    type Array = ZstdArray;
    type Encoding = ZstdEncoding;

    type ArrayVTable = Self;
    type CanonicalVTable = Self;
    type OperationsVTable = Self;
    // Validity is derived from the unsliced validity plus the slice bounds
    // (see the `ValiditySliceHelper` impl in this file).
    type ValidityVTable = ValidityVTableFromValiditySliceHelper;
    type VisitorVTable = Self;
    type ComputeVTable = NotSupported;
    type EncodeVTable = Self;
    type SerdeVTable = Self;
    type PipelineVTable = NotSupported;

    fn id(_encoding: &Self::Encoding) -> EncodingId {
        EncodingId::new_ref("vortex.zstd")
    }

    fn encoding(_array: &Self::Array) -> EncodingRef {
        EncodingRef::new_ref(ZstdEncoding.as_ref())
    }
}
72
/// Marker type for the Zstd encoding (id: `vortex.zstd`).
#[derive(Clone, Debug)]
pub struct ZstdEncoding;
75
/// A zstd-compressed array: one or more zstd frames (optionally sharing a
/// trained dictionary) holding only the *valid* values, plus the original
/// validity and a lazily-applied slice window.
#[derive(Clone, Debug)]
pub struct ZstdArray {
    /// Shared zstd dictionary used by all frames, if one was trained.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// The compressed zstd frames.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Per-frame uncompressed sizes / value counts plus the dictionary size.
    pub(crate) metadata: ZstdMetadata,
    dtype: DType,
    /// Validity of the *unsliced* array; slicing is applied lazily through
    /// `slice_start`/`slice_stop`.
    pub(crate) unsliced_validity: Validity,
    // Row count (including nulls) before any slicing.
    unsliced_n_rows: usize,
    stats_set: ArrayStats,
    // Logical window into the unsliced rows: start inclusive, stop exclusive.
    slice_start: usize,
    slice_stop: usize,
}
88
/// Intermediate result of compression: the trained dictionary (if any), the
/// compressed frames, and the per-frame metadata.
struct Frames {
    /// `None` when there were too few frames to train a dictionary.
    dictionary: Option<ByteBuffer>,
    frames: Vec<ByteBuffer>,
    frame_metas: Vec<ZstdFrameMetadata>,
}
94
/// Picks the maximum dictionary size for a given amount of uncompressed data.
///
/// Follows the recommendation from
/// https://github.com/facebook/zstd/blob/v1.5.5/lib/zdict.h#L190 — roughly
/// 1/100 of the data size, capped at 100 KiB. The floor of 256 bytes exists
/// because zstd appears unable to train dictionaries smaller than that.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const MIN_DICT_SIZE: usize = 256;
    const MAX_DICT_SIZE: usize = 100 * 1024;
    let target = uncompressed_size / 100;
    target.max(MIN_DICT_SIZE).min(MAX_DICT_SIZE)
}
102
103fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
104    let mask = parray.validity_mask();
105    Ok(filter(&parray.to_array(), &mask)?.to_primitive())
106}
107
/// Concatenates the valid (non-null) values of `vbv` into a single packed
/// buffer, each value prefixed by its byte length as a little-endian
/// `ViewLen`.
///
/// Returns the packed buffer together with, for each valid value, the byte
/// offset of its length prefix within that buffer; those offsets are later
/// used to pick frame boundaries.
fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
    let mask = vbv.validity_mask();
    let buffer_and_value_byte_indices = match mask.boolean_buffer() {
        // All elements null: nothing to pack.
        AllOr::None => (Buffer::empty(), Vec::new()),
        _ => {
            // Reserve room for the payload bytes plus one length prefix per
            // valid value. NOTE(review): `vbv.nbytes()` presumably covers at
            // least the payload of the valid values — confirm this upper
            // bound against `extend_trusted`'s requirements.
            let mut buffer = BufferMut::with_capacity(
                usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
                    + mask.true_count() * size_of::<ViewLen>(),
            );
            let mut value_byte_indices = Vec::new();
            vbv.with_iterator(|iterator| {
                // by flattening, we should omit nulls
                for value in iterator.flatten() {
                    // Record where this value's length prefix begins.
                    value_byte_indices.push(buffer.len());
                    // here's where we write the string lengths
                    buffer
                        .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
                    buffer.extend_from_slice(value);
                }
                Ok::<_, VortexError>(())
            })??;
            (buffer.freeze(), value_byte_indices)
        }
    };
    Ok(buffer_and_value_byte_indices)
}
134
135fn reconstruct_views(buffer: ByteBuffer) -> Buffer<BinaryView> {
136    let mut res = BufferMut::<BinaryView>::empty();
137    let mut offset = 0;
138    while offset < buffer.len() {
139        let str_len = ViewLen::from_le_bytes(
140            buffer
141                .get(offset..offset + size_of::<ViewLen>())
142                .vortex_expect("corrupted zstd length")
143                .try_into()
144                .vortex_expect("must fit ViewLen size"),
145        ) as usize;
146        offset += size_of::<ViewLen>();
147        let value = &buffer[offset..offset + str_len];
148        res.push(BinaryView::make_view(
149            value,
150            0,
151            u32::try_from(offset).vortex_expect("offset must fit in u32"),
152        ));
153        offset += str_len;
154    }
155    res.freeze()
156}
157
158impl ZstdArray {
    /// Assembles a `ZstdArray` from already-compressed parts.
    ///
    /// `n_rows` is the row count of the unsliced logical array (including
    /// nulls); the slice window is initialized to cover all rows.
    pub fn new(
        dictionary: Option<ByteBuffer>,
        frames: Vec<ByteBuffer>,
        dtype: DType,
        metadata: ZstdMetadata,
        n_rows: usize,
        validity: Validity,
    ) -> Self {
        Self {
            dictionary,
            frames,
            metadata,
            dtype,
            unsliced_validity: validity,
            unsliced_n_rows: n_rows,
            stats_set: Default::default(),
            slice_start: 0,
            slice_stop: n_rows,
        }
    }
179
180    fn compress_values(
181        value_bytes: &ByteBuffer,
182        frame_byte_starts: &[usize],
183        level: i32,
184        values_per_frame: usize,
185        n_values: usize,
186    ) -> VortexResult<Frames> {
187        let n_frames = frame_byte_starts.len();
188
189        // Would-be sample sizes if we end up applying zstd dictionary
190        let mut sample_sizes = Vec::with_capacity(n_frames);
191        for i in 0..n_frames {
192            let frame_byte_end = frame_byte_starts
193                .get(i + 1)
194                .copied()
195                .unwrap_or(value_bytes.len());
196            sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
197        }
198        debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
199
200        let (dictionary, mut compressor) = if sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY {
201            // no dictionary
202            (None, zstd::bulk::Compressor::new(level)?)
203        } else {
204            // with dictionary
205            let max_dict_size = choose_max_dict_size(value_bytes.len());
206            let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
207                .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
208
209            let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
210            (Some(ByteBuffer::from(dict)), compressor)
211        };
212
213        let mut frame_metas = vec![];
214        let mut frames = vec![];
215        for i in 0..n_frames {
216            let frame_byte_end = frame_byte_starts
217                .get(i + 1)
218                .copied()
219                .unwrap_or(value_bytes.len());
220            let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
221            let compressed = compressor
222                .compress(uncompressed)
223                .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
224            frame_metas.push(ZstdFrameMetadata {
225                uncompressed_size: uncompressed.len() as u64,
226                n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
227            });
228            frames.push(ByteBuffer::from(compressed));
229        }
230
231        Ok(Frames {
232            dictionary,
233            frames,
234            frame_metas,
235        })
236    }
237
238    pub fn from_primitive(
239        parray: &PrimitiveArray,
240        level: i32,
241        values_per_frame: usize,
242    ) -> VortexResult<Self> {
243        let dtype = parray.dtype().clone();
244        let byte_width = parray.ptype().byte_width();
245
246        // We compress only the valid elements.
247        let values = collect_valid_primitive(parray)?;
248        let n_values = values.len();
249        let values_per_frame = if values_per_frame > 0 {
250            values_per_frame
251        } else {
252            n_values
253        };
254
255        let value_bytes = values.byte_buffer();
256        let frame_byte_starts = (0..n_values * byte_width)
257            .step_by(values_per_frame * byte_width)
258            .collect::<Vec<_>>();
259        let Frames {
260            dictionary,
261            frames,
262            frame_metas,
263        } = Self::compress_values(
264            value_bytes,
265            &frame_byte_starts,
266            level,
267            values_per_frame,
268            n_values,
269        )?;
270
271        let metadata = ZstdMetadata {
272            dictionary_size: dictionary
273                .as_ref()
274                .map_or(0, |dict| dict.len())
275                .try_into()?,
276            frames: frame_metas,
277        };
278
279        Ok(ZstdArray::new(
280            dictionary,
281            frames,
282            dtype,
283            metadata,
284            parray.len(),
285            parray.validity().clone(),
286        ))
287    }
288
289    pub fn from_var_bin_view(
290        vbv: &VarBinViewArray,
291        level: i32,
292        values_per_frame: usize,
293    ) -> VortexResult<Self> {
294        // Approach for strings: we prefix each string with its length as a u32.
295        // This is the same as what Parquet does. In some cases it may be better
296        // to separate the binary data and lengths as two separate streams, but
297        // this approach is simpler and can be best in cases when there is
298        // mutual information between strings and their lengths.
299        let dtype = vbv.dtype().clone();
300
301        // We compress only the valid elements.
302        let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
303        let n_values = value_byte_indices.len();
304        let values_per_frame = if values_per_frame > 0 {
305            values_per_frame
306        } else {
307            n_values
308        };
309
310        let frame_byte_starts = (0..n_values)
311            .step_by(values_per_frame)
312            .map(|i| value_byte_indices[i])
313            .collect::<Vec<_>>();
314        let Frames {
315            dictionary,
316            frames,
317            frame_metas,
318        } = Self::compress_values(
319            &value_bytes,
320            &frame_byte_starts,
321            level,
322            values_per_frame,
323            n_values,
324        )?;
325
326        let metadata = ZstdMetadata {
327            dictionary_size: dictionary
328                .as_ref()
329                .map_or(0, |dict| dict.len())
330                .try_into()?,
331            frames: frame_metas,
332        };
333        Ok(ZstdArray::new(
334            dictionary,
335            frames,
336            dtype,
337            metadata,
338            vbv.len(),
339            vbv.validity().clone(),
340        ))
341    }
342
343    pub fn from_canonical(
344        canonical: &Canonical,
345        level: i32,
346        values_per_frame: usize,
347    ) -> VortexResult<Option<Self>> {
348        match canonical {
349            Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
350                parray,
351                level,
352                values_per_frame,
353            )?)),
354            Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
355                vbv,
356                level,
357                values_per_frame,
358            )?)),
359            _ => Ok(None),
360        }
361    }
362
363    pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
364        Self::from_canonical(&array.to_canonical(), level, values_per_frame)?
365            .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
366    }
367
368    fn byte_width(&self) -> usize {
369        if self.dtype.is_primitive() {
370            self.dtype.as_ptype().byte_width()
371        } else {
372            1
373        }
374    }
375
    /// Decompresses the frames covering the current slice window and returns
    /// the values as a canonically-encoded array (`PrimitiveArray` or
    /// `VarBinViewArray`, depending on `dtype`).
    pub fn decompress(&self) -> ArrayRef {
        // To start, we figure out which frames we need to decompress, and with
        // what row offset into the first such frame.
        let byte_width = self.byte_width();
        let slice_n_rows = self.slice_stop - self.slice_start;
        // Translate the row-based slice bounds into indices over *valid
        // values only*, since nulls were dropped before compression.
        let slice_value_indices = self
            .unsliced_validity
            .to_mask(self.unsliced_n_rows)
            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);

        let slice_value_idx_start = slice_value_indices[0];
        let slice_value_idx_stop = slice_value_indices[1];

        let mut frames_to_decompress = vec![];
        let mut value_idx_start = 0;
        let mut uncompressed_size_to_decompress = 0;
        let mut n_skipped_values = 0;
        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
            // Frames entirely past the slice are never visited.
            if value_idx_start >= slice_value_idx_stop {
                break;
            }

            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
                .vortex_expect("Uncompressed size must fit in usize");
            let frame_n_values = if frame_meta.n_values == 0 {
                // possibly older primitive-only metadata that just didn't store this
                frame_uncompressed_size / byte_width
            } else {
                usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
            };

            let value_idx_stop = value_idx_start + frame_n_values;
            if value_idx_stop > slice_value_idx_start {
                // we need this frame
                frames_to_decompress.push(frame);
                uncompressed_size_to_decompress += frame_uncompressed_size;
            } else {
                // Frame lies entirely before the slice: skip it, but remember
                // how many values were skipped so indices can be rebased below.
                n_skipped_values += frame_n_values;
            }
            value_idx_start = value_idx_stop;
        }

        // then we actually decompress those frames
        let mut decompressor = if let Some(dictionary) = &self.dictionary {
            zstd::bulk::Decompressor::with_dictionary(dictionary)
        } else {
            zstd::bulk::Decompressor::new()
        }
        .vortex_expect("Decompressor encountered io error");
        // Align for the primitive type so the buffer can be reinterpreted
        // directly as typed values.
        let mut decompressed = ByteBufferMut::with_capacity_aligned(
            uncompressed_size_to_decompress,
            Alignment::new(byte_width),
        );
        unsafe {
            // safety: we immediately fill all bytes in the following loop,
            // assuming our metadata's uncompressed size is correct
            decompressed.set_len(uncompressed_size_to_decompress);
        }
        let mut uncompressed_start = 0;
        for frame in frames_to_decompress {
            // Each frame writes into its own disjoint region of the buffer.
            let uncompressed_written = decompressor
                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])
                .vortex_expect("error while decompressing zstd array");
            uncompressed_start += uncompressed_written;
        }
        if uncompressed_start != uncompressed_size_to_decompress {
            vortex_panic!(
                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
                uncompressed_size_to_decompress,
                uncompressed_start
            );
        }

        let decompressed = decompressed.freeze();
        // Last, we slice the exact values requested out of the decompressed data.
        let slice_validity = self
            .unsliced_validity
            .slice(self.slice_start..self.slice_stop);

        match &self.dtype {
            DType::Primitive(..) => {
                // Rebase value indices by `n_skipped_values`: the decompressed
                // buffer begins at the first frame we kept, not at value 0.
                let slice_values_buffer = decompressed.slice(
                    (slice_value_idx_start - n_skipped_values) * byte_width
                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
                );
                let primitive = PrimitiveArray::from_values_byte_buffer(
                    slice_values_buffer,
                    self.dtype.as_ptype(),
                    slice_validity,
                    slice_n_rows,
                );

                primitive.into_array()
            }
            DType::Binary(_) | DType::Utf8(_) => {
                // The decompressed buffer is a bunch of interleaved u32 lengths
                // and strings of those lengths, we need to reconstruct the
                // views into those strings by passing through the buffer.
                let views = reconstruct_views(decompressed.clone()).slice(
                    slice_value_idx_start - n_skipped_values
                        ..slice_value_idx_stop - n_skipped_values,
                );

                // SAFETY: we properly construct the views inside `reconstruct_views`
                let vbv = unsafe {
                    VarBinViewArray::new_unchecked(
                        views,
                        Arc::from([decompressed]),
                        self.dtype.clone(),
                        slice_validity,
                    )
                };
                vbv.into_array()
            }
            _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype),
        }
    }
493
494    pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
495        ZstdArray {
496            slice_start: self.slice_start + start,
497            slice_stop: self.slice_start + stop,
498            stats_set: Default::default(),
499            ..self.clone()
500        }
501    }
502
    /// Logical type of the values.
    pub(crate) fn dtype(&self) -> &DType {
        &self.dtype
    }
506
    /// Inclusive start of the logical slice, in rows of the unsliced array.
    pub(crate) fn slice_start(&self) -> usize {
        self.slice_start
    }
510
    /// Exclusive end of the logical slice, in rows of the unsliced array.
    pub(crate) fn slice_stop(&self) -> usize {
        self.slice_stop
    }
514
    /// Row count (including nulls) of the array before any slicing.
    pub(crate) fn unsliced_n_rows(&self) -> usize {
        self.unsliced_n_rows
    }
518}
519
impl ValiditySliceHelper for ZstdArray {
    // Exposes the unsliced validity plus the slice bounds so the shared
    // `ValidityVTableFromValiditySliceHelper` can derive the sliced validity.
    fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
        (&self.unsliced_validity, self.slice_start, self.slice_stop)
    }
}
525
impl ArrayVTable<ZstdVTable> for ZstdVTable {
    fn len(array: &ZstdArray) -> usize {
        // Logical length is the slice width, not the unsliced row count.
        array.slice_stop - array.slice_start
    }

    fn dtype(array: &ZstdArray) -> &DType {
        &array.dtype
    }

    fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
        array.stats_set.to_ref(array.as_ref())
    }
}
539
impl CanonicalVTable<ZstdVTable> for ZstdVTable {
    fn canonicalize(array: &ZstdArray) -> Canonical {
        // `decompress` already yields a canonically-encoded array.
        array.decompress().to_canonical()
    }
}
545
impl OperationsVTable<ZstdVTable> for ZstdVTable {
    fn slice(array: &ZstdArray, range: Range<usize>) -> ArrayRef {
        // Cheap: only narrows the logical window; no decompression happens here.
        array._slice(range.start, range.end).into_array()
    }

    fn scalar_at(array: &ZstdArray, index: usize) -> Scalar {
        // Decompresses only the frame(s) covering `index`, but still pays per
        // call — prefer canonicalizing for bulk access.
        array._slice(index, index + 1).decompress().scalar_at(0)
    }
}
554}