// vortex_zstd/array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use itertools::Itertools as _;
9use prost::Message as _;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayRef;
13use vortex_array::Canonical;
14use vortex_array::DynArray;
15use vortex_array::ExecutionCtx;
16use vortex_array::ExecutionStep;
17use vortex_array::IntoArray;
18use vortex_array::LEGACY_SESSION;
19use vortex_array::Precision;
20use vortex_array::ProstMetadata;
21use vortex_array::ToCanonical;
22use vortex_array::VortexSessionExecute;
23use vortex_array::accessor::ArrayAccessor;
24use vortex_array::arrays::ConstantArray;
25use vortex_array::arrays::PrimitiveArray;
26use vortex_array::arrays::VarBinViewArray;
27use vortex_array::arrays::varbinview::build_views::BinaryView;
28use vortex_array::buffer::BufferHandle;
29use vortex_array::dtype::DType;
30use vortex_array::scalar::Scalar;
31use vortex_array::serde::ArrayChildren;
32use vortex_array::stats::ArrayStats;
33use vortex_array::stats::StatsSetRef;
34use vortex_array::validity::Validity;
35use vortex_array::vtable;
36use vortex_array::vtable::ArrayId;
37use vortex_array::vtable::OperationsVTable;
38use vortex_array::vtable::VTable;
39use vortex_array::vtable::ValidityHelper;
40use vortex_array::vtable::ValiditySliceHelper;
41use vortex_array::vtable::ValidityVTableFromValiditySliceHelper;
42use vortex_array::vtable::validity_nchildren;
43use vortex_array::vtable::validity_to_child;
44use vortex_buffer::Alignment;
45use vortex_buffer::Buffer;
46use vortex_buffer::BufferMut;
47use vortex_buffer::ByteBuffer;
48use vortex_buffer::ByteBufferMut;
49use vortex_error::VortexError;
50use vortex_error::VortexExpect;
51use vortex_error::VortexResult;
52use vortex_error::vortex_bail;
53use vortex_error::vortex_ensure;
54use vortex_error::vortex_err;
55use vortex_error::vortex_panic;
56use vortex_mask::AllOr;
57use vortex_session::VortexSession;
58
59use crate::ZstdFrameMetadata;
60use crate::ZstdMetadata;
61
62// Zstd doesn't support training dictionaries on very few samples.
63const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
64type ViewLen = u32;
65
66// Overall approach here:
67// Zstd can be used on the whole array (values_per_frame = 0), resulting in a single Zstd
68// frame, or it can be used with a dictionary (values_per_frame < # values), resulting in
69// multiple Zstd frames sharing a common dictionary. This latter case is helpful if you
70// want somewhat faster access to slices or individual rows, allowing us to only
71// decompress the necessary frames.
72
73// Visually, during decompression, we have an interval of frames we're
74// decompressing and a tighter interval of the slice we actually care about.
75// |=============values (all valid elements)==============|
76// |<-skipped_uncompressed->|----decompressed-------------|
77//                              |------slice-------|
78//                              ^                  ^
79// |<-slice_uncompressed_start->|                  |
80// |<------------slice_uncompressed_stop---------->|
81// We then insert these values to the correct position using a primitive array
82// constructor.
83
// Declare the `Zstd` encoding; the `vtable!` macro expands to the array
// wrapper/downcast plumbing (see `vortex_array::vtable`).
vtable!(Zstd);
85
impl VTable for Zstd {
    type Array = ZstdArray;

    type Metadata = ProstMetadata<ZstdMetadata>;
    type OperationsVTable = Self;
    type ValidityVTable = ValidityVTableFromValiditySliceHelper;

    fn id(_array: &Self::Array) -> ArrayId {
        Self::ID
    }

    // The logical length is the sliced window, not the full unsliced extent.
    fn len(array: &ZstdArray) -> usize {
        array.slice_stop - array.slice_start
    }

    fn dtype(array: &ZstdArray) -> &DType {
        &array.dtype
    }

    fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
        array.stats_set.to_ref(array.as_ref())
    }

    // Hashes the same fields `array_eq` compares. The dictionary is prefixed
    // with a presence flag so `None` and `Some` hash differently.
    fn array_hash<H: std::hash::Hasher>(array: &ZstdArray, state: &mut H, precision: Precision) {
        match &array.dictionary {
            Some(dict) => {
                true.hash(state);
                dict.array_hash(state, precision);
            }
            None => {
                false.hash(state);
            }
        }
        for frame in &array.frames {
            frame.array_hash(state, precision);
        }
        array.dtype.hash(state);
        array.unsliced_validity.array_hash(state, precision);
        array.unsliced_n_rows.hash(state);
        array.slice_start.hash(state);
        array.slice_stop.hash(state);
    }

    // Equality over dictionary, frames, and all logical fields; must stay in
    // sync with `array_hash` above.
    fn array_eq(array: &ZstdArray, other: &ZstdArray, precision: Precision) -> bool {
        if !match (&array.dictionary, &other.dictionary) {
            (Some(d1), Some(d2)) => d1.array_eq(d2, precision),
            (None, None) => true,
            _ => false,
        } {
            return false;
        }
        if array.frames.len() != other.frames.len() {
            return false;
        }
        for (a, b) in array.frames.iter().zip(&other.frames) {
            if !a.array_eq(b, precision) {
                return false;
            }
        }
        array.dtype == other.dtype
            && array
                .unsliced_validity
                .array_eq(&other.unsliced_validity, precision)
            && array.unsliced_n_rows == other.unsliced_n_rows
            && array.slice_start == other.slice_start
            && array.slice_stop == other.slice_stop
    }

    // One buffer per compressed frame, plus one for the dictionary if present.
    fn nbuffers(array: &ZstdArray) -> usize {
        array.dictionary.is_some() as usize + array.frames.len()
    }

    // Buffer 0 is the dictionary when present, and frame buffers follow
    // shifted by one; without a dictionary, buffer `idx` is frame `idx`.
    fn buffer(array: &ZstdArray, idx: usize) -> BufferHandle {
        if let Some(dict) = &array.dictionary {
            if idx == 0 {
                return BufferHandle::new_host(dict.clone());
            }
            BufferHandle::new_host(array.frames[idx - 1].clone())
        } else {
            BufferHandle::new_host(array.frames[idx].clone())
        }
    }

    // Names mirror the layout used by `buffer` above.
    fn buffer_name(array: &ZstdArray, idx: usize) -> Option<String> {
        if array.dictionary.is_some() {
            if idx == 0 {
                Some("dictionary".to_string())
            } else {
                Some(format!("frame_{}", idx - 1))
            }
        } else {
            Some(format!("frame_{idx}"))
        }
    }

    fn nchildren(array: &ZstdArray) -> usize {
        validity_nchildren(&array.unsliced_validity)
    }

    // The only possible child is the validity array.
    // NOTE(review): `idx` is not used for the lookup, only in the panic
    // message; `validity_to_child` yields the sole validity child.
    fn child(array: &ZstdArray, idx: usize) -> ArrayRef {
        validity_to_child(&array.unsliced_validity, array.unsliced_n_rows)
            .unwrap_or_else(|| vortex_panic!("ZstdArray child index {idx} out of bounds"))
    }

    fn child_name(_array: &ZstdArray, idx: usize) -> String {
        match idx {
            0 => "validity".to_string(),
            _ => vortex_panic!("ZstdArray child_name index {idx} out of bounds"),
        }
    }

    fn metadata(array: &ZstdArray) -> VortexResult<Self::Metadata> {
        Ok(ProstMetadata(array.metadata.clone()))
    }

    // Metadata round-trips through prost (protobuf) encoding.
    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
        Ok(Some(metadata.0.encode_to_vec()))
    }

    fn deserialize(
        bytes: &[u8],
        _dtype: &DType,
        _len: usize,
        _buffers: &[BufferHandle],
        _session: &VortexSession,
    ) -> VortexResult<Self::Metadata> {
        Ok(ProstMetadata(ZstdMetadata::decode(bytes)?))
    }

    // Rebuilds a ZstdArray from serialized parts. `dictionary_size == 0` in
    // the metadata means no dictionary (all buffers are frames); otherwise
    // buffer 0 is the dictionary.
    fn build(
        dtype: &DType,
        len: usize,
        metadata: &Self::Metadata,
        buffers: &[BufferHandle],
        children: &dyn ArrayChildren,
    ) -> VortexResult<ZstdArray> {
        let validity = if children.is_empty() {
            Validity::from(dtype.nullability())
        } else if children.len() == 1 {
            let validity = children.get(0, &Validity::DTYPE, len)?;
            Validity::Array(validity)
        } else {
            vortex_bail!("ZstdArray expected 0 or 1 child, got {}", children.len());
        };

        let (dictionary_buffer, compressed_buffers) = if metadata.0.dictionary_size == 0 {
            // no dictionary
            (
                None,
                buffers
                    .iter()
                    .map(|b| b.clone().try_to_host_sync())
                    .collect::<VortexResult<Vec<_>>>()?,
            )
        } else {
            // with dictionary
            (
                Some(buffers[0].clone().try_to_host_sync()?),
                buffers[1..]
                    .iter()
                    .map(|b| b.clone().try_to_host_sync())
                    .collect::<VortexResult<Vec<_>>>()?,
            )
        };

        Ok(ZstdArray::new(
            dictionary_buffer,
            compressed_buffers,
            dtype.clone(),
            metadata.0.clone(),
            len,
            validity,
        ))
    }

    // Replaces the validity child, or resets it from the dtype's nullability
    // when no child is supplied.
    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
        vortex_ensure!(
            children.len() <= 1,
            "ZstdArray expects at most 1 child (validity), got {}",
            children.len()
        );

        array.unsliced_validity = if children.is_empty() {
            Validity::from(array.dtype.nullability())
        } else {
            Validity::Array(children.into_iter().next().vortex_expect("checked"))
        };

        Ok(())
    }

    // Decompresses to a canonical array and executes that.
    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
        array
            .decompress(ctx)?
            .execute::<ArrayRef>(ctx)
            .map(ExecutionStep::Done)
    }

    // Delegates parent-reduction to this crate's rewrite rules.
    fn reduce_parent(
        array: &Self::Array,
        parent: &ArrayRef,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        crate::rules::RULES.evaluate(array, parent, child_idx)
    }
}
292
/// Marker type for the zstd encoding.
#[derive(Debug)]
pub struct Zstd;
295
impl Zstd {
    /// Unique identifier of the zstd encoding.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.zstd");
}
299
/// A zstd-compressed array: one or more compressed frames, optionally sharing
/// a trained dictionary, together with the validity and row count of the
/// uncompressed data and a slice window into it.
#[derive(Clone, Debug)]
pub struct ZstdArray {
    /// Optional trained zstd dictionary shared by all frames.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// The compressed frames.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Dictionary size and per-frame sizes/value counts.
    pub(crate) metadata: ZstdMetadata,
    // Logical element type of the uncompressed data.
    dtype: DType,
    /// Validity of the *unsliced* array; only valid elements are compressed.
    pub(crate) unsliced_validity: Validity,
    // Row count of the unsliced array.
    unsliced_n_rows: usize,
    // Lazily-populated statistics.
    stats_set: ArrayStats,
    // Slice window (in rows) into the unsliced array; len = stop - start.
    slice_start: usize,
    slice_stop: usize,
}
312
/// The parts of a [`ZstdArray`] returned by [`ZstdArray::into_parts`].
#[derive(Debug)]
pub struct ZstdArrayParts {
    /// The optional dictionary used for compression.
    pub dictionary: Option<ByteBuffer>,
    /// The compressed frames.
    pub frames: Vec<ByteBuffer>,
    /// The compression metadata.
    pub metadata: ZstdMetadata,
    /// The data type of the uncompressed array.
    pub dtype: DType,
    /// The validity of the uncompressed array.
    pub validity: Validity,
    /// The number of rows in the uncompressed array.
    pub n_rows: usize,
    /// Slice start offset (in rows of the uncompressed array).
    pub slice_start: usize,
    /// Slice stop offset (exclusive, in rows of the uncompressed array).
    pub slice_stop: usize,
}
333
/// Output of `ZstdArray::compress_values`: the compressed frames, their
/// per-frame metadata, and the trained dictionary (if one was used).
struct Frames {
    dictionary: Option<ByteBuffer>,
    frames: Vec<ByteBuffer>,
    frame_metas: Vec<ZstdFrameMetadata>,
}
339
/// Picks a maximum dictionary size for the given input size.
///
/// Heuristic from the zstd recommendations in
/// https://github.com/facebook/zstd/blob/v1.5.5/lib/zdict.h#L190:
/// roughly 1/100 of the data, bounded to [256 B, 100 KiB]. zstd cannot train
/// dictionaries smaller than 256 bytes.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const LOWER_BOUND: usize = 256;
    const UPPER_BOUND: usize = 100 * 1024;
    let target = uncompressed_size / 100;
    target.max(LOWER_BOUND).min(UPPER_BOUND)
}
347
348fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
349    let mask = parray.validity_mask()?;
350    Ok(parray.clone().into_array().filter(mask)?.to_primitive())
351}
352
/// Serializes all valid values of `vbv` into one contiguous buffer of
/// u32-length-prefixed byte strings, returning the buffer together with the
/// byte offset at which each value's length prefix starts.
fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
    let mask = vbv.validity_mask()?;
    let buffer_and_value_byte_indices = match mask.bit_buffer() {
        // All-null array: nothing to serialize.
        AllOr::None => (Buffer::empty(), Vec::new()),
        _ => {
            // Reserve room for the raw bytes plus one ViewLen prefix per
            // valid value.
            let mut buffer = BufferMut::with_capacity(
                usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
                    + mask.true_count() * size_of::<ViewLen>(),
            );
            let mut value_byte_indices = Vec::new();
            vbv.with_iterator(|iterator| {
                // by flattening, we should omit nulls
                for value in iterator.flatten() {
                    value_byte_indices.push(buffer.len());
                    // here's where we write the string lengths
                    buffer
                        .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
                    buffer.extend_from_slice(value);
                }
                Ok::<_, VortexError>(())
            })?;
            (buffer.freeze(), value_byte_indices)
        }
    };
    Ok(buffer_and_value_byte_indices)
}
379
/// Reconstruct BinaryView structs from length-prefixed byte data.
///
/// The buffer contains interleaved u32 lengths (little-endian) and string data.
/// Each view references buffer index 0 at the offset just past its length
/// prefix.
///
/// # Panics
/// Panics if the buffer is truncated mid-prefix or a length prefix points past
/// the end of the buffer.
pub fn reconstruct_views(buffer: &ByteBuffer) -> Buffer<BinaryView> {
    let mut res = BufferMut::<BinaryView>::empty();
    let mut offset = 0;
    while offset < buffer.len() {
        // Read the little-endian u32 length prefix.
        let str_len = ViewLen::from_le_bytes(
            buffer
                .get(offset..offset + size_of::<ViewLen>())
                .vortex_expect("corrupted zstd length")
                .try_into()
                .ok()
                .vortex_expect("must fit ViewLen size"),
        ) as usize;
        offset += size_of::<ViewLen>();
        let value = &buffer[offset..offset + str_len];
        // Build a view pointing at buffer 0, just past the length prefix.
        res.push(BinaryView::make_view(
            value,
            0,
            u32::try_from(offset).vortex_expect("offset must fit in u32"),
        ));
        offset += str_len;
    }
    res.freeze()
}
406
407impl ZstdArray {
    /// Builds a [`ZstdArray`] from already-compressed parts; the slice window
    /// initially covers all `n_rows` rows.
    pub fn new(
        dictionary: Option<ByteBuffer>,
        frames: Vec<ByteBuffer>,
        dtype: DType,
        metadata: ZstdMetadata,
        n_rows: usize,
        validity: Validity,
    ) -> Self {
        Self {
            dictionary,
            frames,
            metadata,
            dtype,
            unsliced_validity: validity,
            unsliced_n_rows: n_rows,
            stats_set: Default::default(),
            slice_start: 0,
            slice_stop: n_rows,
        }
    }
428
429    fn compress_values(
430        value_bytes: &ByteBuffer,
431        frame_byte_starts: &[usize],
432        level: i32,
433        values_per_frame: usize,
434        n_values: usize,
435        use_dictionary: bool,
436    ) -> VortexResult<Frames> {
437        let n_frames = frame_byte_starts.len();
438
439        // Would-be sample sizes if we end up applying zstd dictionary
440        let mut sample_sizes = Vec::with_capacity(n_frames);
441        for i in 0..n_frames {
442            let frame_byte_end = frame_byte_starts
443                .get(i + 1)
444                .copied()
445                .unwrap_or(value_bytes.len());
446            sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
447        }
448        debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
449
450        let (dictionary, mut compressor) = if !use_dictionary
451            || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY
452        {
453            // no dictionary
454            (None, zstd::bulk::Compressor::new(level)?)
455        } else {
456            // with dictionary
457            let max_dict_size = choose_max_dict_size(value_bytes.len());
458            let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
459                .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
460
461            let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
462            (Some(ByteBuffer::from(dict)), compressor)
463        };
464
465        let mut frame_metas = vec![];
466        let mut frames = vec![];
467        for i in 0..n_frames {
468            let frame_byte_end = frame_byte_starts
469                .get(i + 1)
470                .copied()
471                .unwrap_or(value_bytes.len());
472
473            let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
474            let compressed = compressor
475                .compress(uncompressed)
476                .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
477            frame_metas.push(ZstdFrameMetadata {
478                uncompressed_size: uncompressed.len() as u64,
479                n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
480            });
481            frames.push(ByteBuffer::from(compressed));
482        }
483
484        Ok(Frames {
485            dictionary,
486            frames,
487            frame_metas,
488        })
489    }
490
    /// Creates a ZstdArray from a primitive array.
    ///
    /// # Arguments
    /// * `parray` - The primitive array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_primitive(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        // Dictionary-enabled variant; see `from_primitive_without_dict`.
        Self::from_primitive_impl(parray, level, values_per_frame, true)
    }
504
    /// Creates a ZstdArray from a primitive array without using a dictionary.
    ///
    /// This is useful when the compressed data will be decompressed by systems
    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
    ///
    /// Note: Without a dictionary, each frame is compressed independently.
    /// Dictionaries are trained from sample data from previously seen frames,
    /// to improve compression ratio.
    ///
    /// # Arguments
    /// * `parray` - The primitive array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_primitive_without_dict(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, false)
    }
525
526    fn from_primitive_impl(
527        parray: &PrimitiveArray,
528        level: i32,
529        values_per_frame: usize,
530        use_dictionary: bool,
531    ) -> VortexResult<Self> {
532        let dtype = parray.dtype().clone();
533        let byte_width = parray.ptype().byte_width();
534
535        // We compress only the valid elements.
536        let values = collect_valid_primitive(parray)?;
537        let n_values = values.len();
538        let values_per_frame = if values_per_frame > 0 {
539            values_per_frame
540        } else {
541            n_values
542        };
543
544        let value_bytes = values.buffer_handle().try_to_host_sync()?;
545        // Align frames to buffer alignment. This is necessary for overaligned buffers.
546        let alignment = *value_bytes.alignment();
547        let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment;
548
549        let frame_byte_starts = (0..n_values * byte_width)
550            .step_by(step_width)
551            .collect::<Vec<_>>();
552        let Frames {
553            dictionary,
554            frames,
555            frame_metas,
556        } = Self::compress_values(
557            &value_bytes,
558            &frame_byte_starts,
559            level,
560            values_per_frame,
561            n_values,
562            use_dictionary,
563        )?;
564
565        let metadata = ZstdMetadata {
566            dictionary_size: dictionary
567                .as_ref()
568                .map_or(0, |dict| dict.len())
569                .try_into()?,
570            frames: frame_metas,
571        };
572
573        Ok(ZstdArray::new(
574            dictionary,
575            frames,
576            dtype,
577            metadata,
578            parray.len(),
579            parray.validity().clone(),
580        ))
581    }
582
    /// Creates a ZstdArray from a VarBinView array.
    ///
    /// # Arguments
    /// * `vbv` - The VarBinView array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_var_bin_view(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        // Dictionary-enabled variant; see `from_var_bin_view_without_dict`.
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, true)
    }
596
    /// Creates a ZstdArray from a VarBinView array without using a dictionary.
    ///
    /// This is useful when the compressed data will be decompressed by systems
    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
    ///
    /// Note: Without a dictionary, each frame is compressed independently.
    /// Dictionaries are trained from sample data from previously seen frames,
    /// to improve compression ratio.
    ///
    /// # Arguments
    /// * `vbv` - The VarBinView array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_var_bin_view_without_dict(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, false)
    }
617
618    fn from_var_bin_view_impl(
619        vbv: &VarBinViewArray,
620        level: i32,
621        values_per_frame: usize,
622        use_dictionary: bool,
623    ) -> VortexResult<Self> {
624        // Approach for strings: we prefix each string with its length as a u32.
625        // This is the same as what Parquet does. In some cases it may be better
626        // to separate the binary data and lengths as two separate streams, but
627        // this approach is simpler and can be best in cases when there is
628        // mutual information between strings and their lengths.
629        let dtype = vbv.dtype().clone();
630
631        // We compress only the valid elements.
632        let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
633        let n_values = value_byte_indices.len();
634        let values_per_frame = if values_per_frame > 0 {
635            values_per_frame
636        } else {
637            n_values
638        };
639
640        let frame_byte_starts = (0..n_values)
641            .step_by(values_per_frame)
642            .map(|i| value_byte_indices[i])
643            .collect::<Vec<_>>();
644        let Frames {
645            dictionary,
646            frames,
647            frame_metas,
648        } = Self::compress_values(
649            &value_bytes,
650            &frame_byte_starts,
651            level,
652            values_per_frame,
653            n_values,
654            use_dictionary,
655        )?;
656
657        let metadata = ZstdMetadata {
658            dictionary_size: dictionary
659                .as_ref()
660                .map_or(0, |dict| dict.len())
661                .try_into()?,
662            frames: frame_metas,
663        };
664        Ok(ZstdArray::new(
665            dictionary,
666            frames,
667            dtype,
668            metadata,
669            vbv.len(),
670            vbv.validity().clone(),
671        ))
672    }
673
674    pub fn from_canonical(
675        canonical: &Canonical,
676        level: i32,
677        values_per_frame: usize,
678    ) -> VortexResult<Option<Self>> {
679        match canonical {
680            Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
681                parray,
682                level,
683                values_per_frame,
684            )?)),
685            Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
686                vbv,
687                level,
688                values_per_frame,
689            )?)),
690            _ => Ok(None),
691        }
692    }
693
    /// Canonicalizes `array` and compresses it; errors if the canonical form
    /// is neither Primitive nor VarBinView.
    pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
        Self::from_canonical(&array.to_canonical()?, level, values_per_frame)?
            .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
    }
698
699    fn byte_width(&self) -> usize {
700        if self.dtype.is_primitive() {
701            self.dtype.as_ptype().byte_width()
702        } else {
703            1
704        }
705    }
706
707    pub fn decompress(&self, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
708        // To start, we figure out which frames we need to decompress, and with
709        // what row offset into the first such frame.
710        let byte_width = self.byte_width();
711        let slice_n_rows = self.slice_stop - self.slice_start;
712        let slice_value_indices = self
713            .unsliced_validity
714            .execute_mask(self.unsliced_n_rows, ctx)?
715            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
716
717        let slice_value_idx_start = slice_value_indices[0];
718        let slice_value_idx_stop = slice_value_indices[1];
719
720        let mut frames_to_decompress = vec![];
721        let mut value_idx_start = 0;
722        let mut uncompressed_size_to_decompress = 0;
723        let mut n_skipped_values = 0;
724        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
725            if value_idx_start >= slice_value_idx_stop {
726                break;
727            }
728
729            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
730                .vortex_expect("Uncompressed size must fit in usize");
731            let frame_n_values = if frame_meta.n_values == 0 {
732                // possibly older primitive-only metadata that just didn't store this
733                frame_uncompressed_size / byte_width
734            } else {
735                usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
736            };
737
738            let value_idx_stop = value_idx_start + frame_n_values;
739            if value_idx_stop > slice_value_idx_start {
740                // we need this frame
741                frames_to_decompress.push(frame);
742                uncompressed_size_to_decompress += frame_uncompressed_size;
743            } else {
744                n_skipped_values += frame_n_values;
745            }
746            value_idx_start = value_idx_stop;
747        }
748
749        // then we actually decompress those frames
750        let mut decompressor = if let Some(dictionary) = &self.dictionary {
751            zstd::bulk::Decompressor::with_dictionary(dictionary)?
752        } else {
753            zstd::bulk::Decompressor::new()?
754        };
755        let mut decompressed = ByteBufferMut::with_capacity_aligned(
756            uncompressed_size_to_decompress,
757            Alignment::new(byte_width),
758        );
759        unsafe {
760            // safety: we immediately fill all bytes in the following loop,
761            // assuming our metadata's uncompressed size is correct
762            decompressed.set_len(uncompressed_size_to_decompress);
763        }
764        let mut uncompressed_start = 0;
765        for frame in frames_to_decompress {
766            let uncompressed_written = decompressor
767                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])?;
768            uncompressed_start += uncompressed_written;
769        }
770        if uncompressed_start != uncompressed_size_to_decompress {
771            vortex_panic!(
772                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
773                uncompressed_size_to_decompress,
774                uncompressed_start
775            );
776        }
777
778        let decompressed = decompressed.freeze();
779        // Last, we slice the exact values requested out of the decompressed data.
780        let mut slice_validity = self
781            .unsliced_validity
782            .slice(self.slice_start..self.slice_stop)?;
783
784        // NOTE: this block handles setting the output type when the validity and DType disagree.
785        //
786        // ZSTD is a compact block compressor, meaning that null values are not stored inline in
787        // the data frames. A ZSTD Array that was initialized must always hold onto its full
788        // validity bitmap, even if sliced to only include non-null values.
789        //
790        // We ensure that the validity of the decompressed array ALWAYS matches the validity
791        // implied by the DType.
792        if !self.dtype().is_nullable() && !matches!(slice_validity, Validity::NonNullable) {
793            assert!(
794                matches!(slice_validity, Validity::AllValid),
795                "ZSTD array expects to be non-nullable but there are nulls after decompression"
796            );
797
798            slice_validity = Validity::NonNullable;
799        } else if self.dtype.is_nullable() && matches!(slice_validity, Validity::NonNullable) {
800            slice_validity = Validity::AllValid;
801        }
802        //
803        // END OF IMPORTANT BLOCK
804        //
805
806        match &self.dtype {
807            DType::Primitive(..) => {
808                let slice_values_buffer = decompressed.slice(
809                    (slice_value_idx_start - n_skipped_values) * byte_width
810                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
811                );
812                let primitive = PrimitiveArray::from_values_byte_buffer(
813                    slice_values_buffer,
814                    self.dtype.as_ptype(),
815                    slice_validity,
816                    slice_n_rows,
817                );
818
819                Ok(primitive.into_array())
820            }
821            DType::Binary(_) | DType::Utf8(_) => {
822                match slice_validity.execute_mask(slice_n_rows, ctx)?.indices() {
823                    AllOr::All => {
824                        // the decompressed buffer is a bunch of interleaved u32 lengths
825                        // and strings of those lengths, we need to reconstruct the
826                        // views into those strings by passing through the buffer.
827                        let valid_views = reconstruct_views(&decompressed).slice(
828                            slice_value_idx_start - n_skipped_values
829                                ..slice_value_idx_stop - n_skipped_values,
830                        );
831
832                        // SAFETY: we properly construct the views inside `reconstruct_views`
833                        Ok(unsafe {
834                            VarBinViewArray::new_unchecked(
835                                valid_views,
836                                Arc::from([decompressed]),
837                                self.dtype.clone(),
838                                slice_validity,
839                            )
840                        }
841                        .into_array())
842                    }
843                    AllOr::None => Ok(ConstantArray::new(
844                        Scalar::null(self.dtype.clone()),
845                        slice_n_rows,
846                    )
847                    .into_array()),
848                    AllOr::Some(valid_indices) => {
849                        // the decompressed buffer is a bunch of interleaved u32 lengths
850                        // and strings of those lengths, we need to reconstruct the
851                        // views into those strings by passing through the buffer.
852                        let valid_views = reconstruct_views(&decompressed).slice(
853                            slice_value_idx_start - n_skipped_values
854                                ..slice_value_idx_stop - n_skipped_values,
855                        );
856
857                        let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
858                        for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
859                            views[*index] = view
860                        }
861
862                        // SAFETY: we properly construct the views inside `reconstruct_views`
863                        Ok(unsafe {
864                            VarBinViewArray::new_unchecked(
865                                views.freeze(),
866                                Arc::from([decompressed]),
867                                self.dtype.clone(),
868                                slice_validity,
869                            )
870                        }
871                        .into_array())
872                    }
873                }
874            }
875            _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype),
876        }
877    }
878
879    pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
880        let new_start = self.slice_start + start;
881        let new_stop = self.slice_start + stop;
882
883        assert!(
884            new_start <= self.slice_stop,
885            "new slice start {new_start} exceeds end {}",
886            self.slice_stop
887        );
888
889        assert!(
890            new_stop <= self.slice_stop,
891            "new slice stop {new_stop} exceeds end {}",
892            self.slice_stop
893        );
894
895        ZstdArray {
896            slice_start: self.slice_start + start,
897            slice_stop: self.slice_start + stop,
898            stats_set: Default::default(),
899            ..self.clone()
900        }
901    }
902
903    /// Consumes the array and returns its parts.
904    pub fn into_parts(self) -> ZstdArrayParts {
905        ZstdArrayParts {
906            dictionary: self.dictionary,
907            frames: self.frames,
908            metadata: self.metadata,
909            dtype: self.dtype,
910            validity: self.unsliced_validity,
911            n_rows: self.unsliced_n_rows,
912            slice_start: self.slice_start,
913            slice_stop: self.slice_stop,
914        }
915    }
916
    /// Returns the logical [`DType`] of the array's values.
    pub(crate) fn dtype(&self) -> &DType {
        &self.dtype
    }
920
    /// Start (inclusive) of the current slice window, in unsliced row positions.
    pub(crate) fn slice_start(&self) -> usize {
        self.slice_start
    }
924
    /// End (exclusive) of the current slice window, in unsliced row positions.
    pub(crate) fn slice_stop(&self) -> usize {
        self.slice_stop
    }
928
    /// Total row count of the underlying (unsliced) array.
    pub(crate) fn unsliced_n_rows(&self) -> usize {
        self.unsliced_n_rows
    }
932}
933
934impl ValiditySliceHelper for ZstdArray {
935    fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
936        (&self.unsliced_validity, self.slice_start, self.slice_stop)
937    }
938}
939
940impl OperationsVTable<Zstd> for Zstd {
941    fn scalar_at(array: &ZstdArray, index: usize) -> VortexResult<Scalar> {
942        let mut ctx = LEGACY_SESSION.create_execution_ctx();
943        array
944            ._slice(index, index + 1)
945            .decompress(&mut ctx)?
946            .scalar_at(0)
947    }
948}