Skip to main content

vortex_zstd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use itertools::Itertools as _;
9use prost::Message as _;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayRef;
13use vortex_array::Canonical;
14use vortex_array::DynArray;
15use vortex_array::ExecutionCtx;
16use vortex_array::ExecutionStep;
17use vortex_array::IntoArray;
18use vortex_array::Precision;
19use vortex_array::ProstMetadata;
20use vortex_array::ToCanonical;
21use vortex_array::accessor::ArrayAccessor;
22use vortex_array::arrays::ConstantArray;
23use vortex_array::arrays::PrimitiveArray;
24use vortex_array::arrays::VarBinViewArray;
25use vortex_array::arrays::varbinview::build_views::BinaryView;
26use vortex_array::buffer::BufferHandle;
27use vortex_array::dtype::DType;
28use vortex_array::scalar::Scalar;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::stats::ArrayStats;
31use vortex_array::stats::StatsSetRef;
32use vortex_array::validity::Validity;
33use vortex_array::vtable;
34use vortex_array::vtable::ArrayId;
35use vortex_array::vtable::OperationsVTable;
36use vortex_array::vtable::VTable;
37use vortex_array::vtable::ValidityHelper;
38use vortex_array::vtable::ValiditySliceHelper;
39use vortex_array::vtable::ValidityVTableFromValiditySliceHelper;
40use vortex_array::vtable::validity_nchildren;
41use vortex_array::vtable::validity_to_child;
42use vortex_buffer::Alignment;
43use vortex_buffer::Buffer;
44use vortex_buffer::BufferMut;
45use vortex_buffer::ByteBuffer;
46use vortex_buffer::ByteBufferMut;
47use vortex_error::VortexError;
48use vortex_error::VortexExpect;
49use vortex_error::VortexResult;
50use vortex_error::vortex_bail;
51use vortex_error::vortex_ensure;
52use vortex_error::vortex_err;
53use vortex_error::vortex_panic;
54use vortex_mask::AllOr;
55use vortex_session::VortexSession;
56
57use crate::ZstdFrameMetadata;
58use crate::ZstdMetadata;
59
// Zstd doesn't support training dictionaries on very few samples.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
// Width type of the little-endian length prefix written before each value when
// compressing variable-length (VarBinView) data; see `collect_valid_vbv`.
type ViewLen = u32;
63
64// Overall approach here:
65// Zstd can be used on the whole array (values_per_frame = 0), resulting in a single Zstd
66// frame, or it can be used with a dictionary (values_per_frame < # values), resulting in
67// multiple Zstd frames sharing a common dictionary. This latter case is helpful if you
68// want somewhat faster access to slices or individual rows, allowing us to only
69// decompress the necessary frames.
70
71// Visually, during decompression, we have an interval of frames we're
72// decompressing and a tighter interval of the slice we actually care about.
73// |=============values (all valid elements)==============|
74// |<-skipped_uncompressed->|----decompressed-------------|
75//                              |------slice-------|
76//                              ^                  ^
77// |<-slice_uncompressed_start->|                  |
78// |<------------slice_uncompressed_stop---------->|
79// We then insert these values to the correct position using a primitive array
80// constructor.
81
// Declares the encoding boilerplate for `Zstd` (wrapper/ref types wired to
// `ZstdVTable` below) — see `vortex_array::vtable!`; exact expansion is
// defined in vortex_array, not visible here.
vtable!(Zstd);
83
impl VTable for ZstdVTable {
    type Array = ZstdArray;

    type Metadata = ProstMetadata<ZstdMetadata>;
    type OperationsVTable = Self;
    // Validity behavior is derived from the slice-aware validity helper.
    type ValidityVTable = ValidityVTableFromValiditySliceHelper;

    /// All Zstd arrays share the single `vortex.zstd` encoding id.
    fn id(_array: &Self::Array) -> ArrayId {
        Self::ID
    }

    /// Logical length is the *sliced* row count, not the unsliced total.
    fn len(array: &ZstdArray) -> usize {
        array.slice_stop - array.slice_start
    }

    fn dtype(array: &ZstdArray) -> &DType {
        &array.dtype
    }

    fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
        array.stats_set.to_ref(array.as_ref())
    }

    /// Hashes every field that participates in `array_eq`, in a fixed order:
    /// dictionary-presence flag (+ dictionary), frames, dtype, unsliced
    /// validity, unsliced row count, then slice bounds. The order must stay
    /// stable for hash/eq consistency.
    fn array_hash<H: std::hash::Hasher>(array: &ZstdArray, state: &mut H, precision: Precision) {
        match &array.dictionary {
            Some(dict) => {
                // Disambiguates Some(dict) from None before hashing contents.
                true.hash(state);
                dict.array_hash(state, precision);
            }
            None => {
                false.hash(state);
            }
        }
        for frame in &array.frames {
            frame.array_hash(state, precision);
        }
        array.dtype.hash(state);
        array.unsliced_validity.array_hash(state, precision);
        array.unsliced_n_rows.hash(state);
        array.slice_start.hash(state);
        array.slice_stop.hash(state);
    }

    /// Structural equality over the same fields hashed in `array_hash`.
    fn array_eq(array: &ZstdArray, other: &ZstdArray, precision: Precision) -> bool {
        // Dictionaries must both be absent or both present-and-equal.
        if !match (&array.dictionary, &other.dictionary) {
            (Some(d1), Some(d2)) => d1.array_eq(d2, precision),
            (None, None) => true,
            _ => false,
        } {
            return false;
        }
        if array.frames.len() != other.frames.len() {
            return false;
        }
        for (a, b) in array.frames.iter().zip(&other.frames) {
            if !a.array_eq(b, precision) {
                return false;
            }
        }
        array.dtype == other.dtype
            && array
                .unsliced_validity
                .array_eq(&other.unsliced_validity, precision)
            && array.unsliced_n_rows == other.unsliced_n_rows
            && array.slice_start == other.slice_start
            && array.slice_stop == other.slice_stop
    }

    /// Buffer layout: the optional dictionary (if any) occupies index 0,
    /// followed by one buffer per compressed frame.
    fn nbuffers(array: &ZstdArray) -> usize {
        array.dictionary.is_some() as usize + array.frames.len()
    }

    fn buffer(array: &ZstdArray, idx: usize) -> BufferHandle {
        if let Some(dict) = &array.dictionary {
            if idx == 0 {
                return BufferHandle::new_host(dict.clone());
            }
            // Frames shift up by one when a dictionary is present.
            BufferHandle::new_host(array.frames[idx - 1].clone())
        } else {
            BufferHandle::new_host(array.frames[idx].clone())
        }
    }

    // NOTE(review): no bounds check on `idx` here — an out-of-range index still
    // yields a `frame_{n}` name; presumably callers only pass idx < nbuffers.
    fn buffer_name(array: &ZstdArray, idx: usize) -> Option<String> {
        if array.dictionary.is_some() {
            if idx == 0 {
                Some("dictionary".to_string())
            } else {
                Some(format!("frame_{}", idx - 1))
            }
        } else {
            Some(format!("frame_{idx}"))
        }
    }

    /// The only possible child is the validity array (0 or 1 children).
    fn nchildren(array: &ZstdArray) -> usize {
        validity_nchildren(&array.unsliced_validity)
    }

    fn child(array: &ZstdArray, idx: usize) -> ArrayRef {
        validity_to_child(&array.unsliced_validity, array.unsliced_n_rows)
            .unwrap_or_else(|| vortex_panic!("ZstdArray child index {idx} out of bounds"))
    }

    fn child_name(_array: &ZstdArray, idx: usize) -> String {
        match idx {
            0 => "validity".to_string(),
            _ => vortex_panic!("ZstdArray child_name index {idx} out of bounds"),
        }
    }

    fn metadata(array: &ZstdArray) -> VortexResult<Self::Metadata> {
        Ok(ProstMetadata(array.metadata.clone()))
    }

    /// Metadata is serialized as protobuf via prost.
    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
        Ok(Some(metadata.0.encode_to_vec()))
    }

    fn deserialize(
        bytes: &[u8],
        _dtype: &DType,
        _len: usize,
        _buffers: &[BufferHandle],
        _session: &VortexSession,
    ) -> VortexResult<Self::Metadata> {
        Ok(ProstMetadata(ZstdMetadata::decode(bytes)?))
    }

    /// Reassembles a `ZstdArray` from serialized parts. `dictionary_size == 0`
    /// in the metadata signals "no dictionary", in which case every buffer is
    /// a frame; otherwise buffer 0 is the dictionary (mirrors `buffer`/
    /// `buffer_name` above).
    fn build(
        dtype: &DType,
        len: usize,
        metadata: &Self::Metadata,
        buffers: &[BufferHandle],
        children: &dyn ArrayChildren,
    ) -> VortexResult<ZstdArray> {
        let validity = if children.is_empty() {
            // No validity child: validity is implied by the dtype's nullability.
            Validity::from(dtype.nullability())
        } else if children.len() == 1 {
            let validity = children.get(0, &Validity::DTYPE, len)?;
            Validity::Array(validity)
        } else {
            vortex_bail!("ZstdArray expected 0 or 1 child, got {}", children.len());
        };

        let (dictionary_buffer, compressed_buffers) = if metadata.0.dictionary_size == 0 {
            // no dictionary
            (
                None,
                buffers
                    .iter()
                    .map(|b| b.clone().try_to_host_sync())
                    .collect::<VortexResult<Vec<_>>>()?,
            )
        } else {
            // with dictionary
            (
                Some(buffers[0].clone().try_to_host_sync()?),
                buffers[1..]
                    .iter()
                    .map(|b| b.clone().try_to_host_sync())
                    .collect::<VortexResult<Vec<_>>>()?,
            )
        };

        Ok(ZstdArray::new(
            dictionary_buffer,
            compressed_buffers,
            dtype.clone(),
            metadata.0.clone(),
            len,
            validity,
        ))
    }

    /// Replaces the validity child (the only child this encoding has).
    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
        vortex_ensure!(
            children.len() <= 1,
            "ZstdArray expects at most 1 child (validity), got {}",
            children.len()
        );

        array.unsliced_validity = if children.is_empty() {
            Validity::from(array.dtype.nullability())
        } else {
            Validity::Array(children.into_iter().next().vortex_expect("checked"))
        };

        Ok(())
    }

    /// Execution fully decompresses the array, then delegates to the
    /// decompressed array's own execution.
    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
        array
            .decompress()?
            .execute::<ArrayRef>(ctx)
            .map(ExecutionStep::Done)
    }

    fn reduce_parent(
        array: &Self::Array,
        parent: &ArrayRef,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        crate::rules::RULES.evaluate(array, parent, child_idx)
    }
}
290
/// Marker type carrying the [`VTable`] implementation for the Zstd encoding.
#[derive(Debug)]
pub struct ZstdVTable;

impl ZstdVTable {
    /// Stable encoding identifier used for (de)serialization and dispatch.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.zstd");
}
297
/// An array whose values are stored as one or more zstd-compressed frames,
/// optionally sharing a trained dictionary. Only valid (non-null) elements are
/// compressed; nulls are tracked via `unsliced_validity`.
#[derive(Clone, Debug)]
pub struct ZstdArray {
    // Optional shared zstd dictionary used by all frames.
    pub(crate) dictionary: Option<ByteBuffer>,
    // Compressed frames, in value order.
    pub(crate) frames: Vec<ByteBuffer>,
    // Per-frame sizes/counts plus dictionary size (protobuf-serialized).
    pub(crate) metadata: ZstdMetadata,
    dtype: DType,
    // Validity over the full, unsliced row range.
    pub(crate) unsliced_validity: Validity,
    // Total rows (including nulls) before any slicing.
    unsliced_n_rows: usize,
    stats_set: ArrayStats,
    // Current slice window into the unsliced rows; len() == stop - start.
    slice_start: usize,
    slice_stop: usize,
}
310
/// The parts of a [`ZstdArray`] returned by [`ZstdArray::into_parts`].
#[derive(Debug)]
pub struct ZstdArrayParts {
    /// The optional dictionary used for compression.
    pub dictionary: Option<ByteBuffer>,
    /// The compressed frames.
    pub frames: Vec<ByteBuffer>,
    /// The compression metadata.
    pub metadata: ZstdMetadata,
    /// The data type of the uncompressed array.
    pub dtype: DType,
    /// The validity of the uncompressed array.
    pub validity: Validity,
    /// The number of rows in the uncompressed array.
    pub n_rows: usize,
    /// Slice start offset (row index into the unsliced array).
    pub slice_start: usize,
    /// Slice stop offset (exclusive row index into the unsliced array).
    pub slice_stop: usize,
}
331
/// Result of compressing a contiguous value buffer into zstd frames.
struct Frames {
    // Trained dictionary shared by the frames, if one was used.
    dictionary: Option<ByteBuffer>,
    // Compressed bytes of each frame.
    frames: Vec<ByteBuffer>,
    // Per-frame uncompressed size and value count, parallel to `frames`.
    frame_metas: Vec<ZstdFrameMetadata>,
}
337
/// Pick the dictionary-size budget for a payload of `uncompressed_size` bytes.
///
/// Follows the recommendation in
/// https://github.com/facebook/zstd/blob/v1.5.5/lib/zdict.h#L190: roughly
/// 1/100th of the data size, capped at 100kB. zstd also appears unable to
/// train dictionaries smaller than 256 bytes, so that is the floor.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const DICT_FLOOR: usize = 256;
    const DICT_CEIL: usize = 100 * 1024;
    let suggested = uncompressed_size / 100;
    suggested.max(DICT_FLOOR).min(DICT_CEIL)
}
345
346fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
347    let mask = parray.validity_mask()?;
348    Ok(parray.clone().into_array().filter(mask)?.to_primitive())
349}
350
351fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
352    let mask = vbv.validity_mask()?;
353    let buffer_and_value_byte_indices = match mask.bit_buffer() {
354        AllOr::None => (Buffer::empty(), Vec::new()),
355        _ => {
356            let mut buffer = BufferMut::with_capacity(
357                usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
358                    + mask.true_count() * size_of::<ViewLen>(),
359            );
360            let mut value_byte_indices = Vec::new();
361            vbv.with_iterator(|iterator| {
362                // by flattening, we should omit nulls
363                for value in iterator.flatten() {
364                    value_byte_indices.push(buffer.len());
365                    // here's where we write the string lengths
366                    buffer
367                        .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
368                    buffer.extend_from_slice(value);
369                }
370                Ok::<_, VortexError>(())
371            })?;
372            (buffer.freeze(), value_byte_indices)
373        }
374    };
375    Ok(buffer_and_value_byte_indices)
376}
377
378/// Reconstruct BinaryView structs from length-prefixed byte data.
379///
380/// The buffer contains interleaved u32 lengths (little-endian) and string data.
381pub fn reconstruct_views(buffer: &ByteBuffer) -> Buffer<BinaryView> {
382    let mut res = BufferMut::<BinaryView>::empty();
383    let mut offset = 0;
384    while offset < buffer.len() {
385        let str_len = ViewLen::from_le_bytes(
386            buffer
387                .get(offset..offset + size_of::<ViewLen>())
388                .vortex_expect("corrupted zstd length")
389                .try_into()
390                .ok()
391                .vortex_expect("must fit ViewLen size"),
392        ) as usize;
393        offset += size_of::<ViewLen>();
394        let value = &buffer[offset..offset + str_len];
395        res.push(BinaryView::make_view(
396            value,
397            0,
398            u32::try_from(offset).vortex_expect("offset must fit in u32"),
399        ));
400        offset += str_len;
401    }
402    res.freeze()
403}
404
405impl ZstdArray {
    /// Construct an unsliced `ZstdArray` from already-compressed parts.
    ///
    /// `n_rows` is the total row count (including nulls); the new array's
    /// slice window covers the full `0..n_rows` range and stats start empty.
    pub fn new(
        dictionary: Option<ByteBuffer>,
        frames: Vec<ByteBuffer>,
        dtype: DType,
        metadata: ZstdMetadata,
        n_rows: usize,
        validity: Validity,
    ) -> Self {
        Self {
            dictionary,
            frames,
            metadata,
            dtype,
            unsliced_validity: validity,
            unsliced_n_rows: n_rows,
            stats_set: Default::default(),
            slice_start: 0,
            slice_stop: n_rows,
        }
    }
426
427    fn compress_values(
428        value_bytes: &ByteBuffer,
429        frame_byte_starts: &[usize],
430        level: i32,
431        values_per_frame: usize,
432        n_values: usize,
433        use_dictionary: bool,
434    ) -> VortexResult<Frames> {
435        let n_frames = frame_byte_starts.len();
436
437        // Would-be sample sizes if we end up applying zstd dictionary
438        let mut sample_sizes = Vec::with_capacity(n_frames);
439        for i in 0..n_frames {
440            let frame_byte_end = frame_byte_starts
441                .get(i + 1)
442                .copied()
443                .unwrap_or(value_bytes.len());
444            sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
445        }
446        debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
447
448        let (dictionary, mut compressor) = if !use_dictionary
449            || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY
450        {
451            // no dictionary
452            (None, zstd::bulk::Compressor::new(level)?)
453        } else {
454            // with dictionary
455            let max_dict_size = choose_max_dict_size(value_bytes.len());
456            let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
457                .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
458
459            let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
460            (Some(ByteBuffer::from(dict)), compressor)
461        };
462
463        let mut frame_metas = vec![];
464        let mut frames = vec![];
465        for i in 0..n_frames {
466            let frame_byte_end = frame_byte_starts
467                .get(i + 1)
468                .copied()
469                .unwrap_or(value_bytes.len());
470
471            let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
472            let compressed = compressor
473                .compress(uncompressed)
474                .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
475            frame_metas.push(ZstdFrameMetadata {
476                uncompressed_size: uncompressed.len() as u64,
477                n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
478            });
479            frames.push(ByteBuffer::from(compressed));
480        }
481
482        Ok(Frames {
483            dictionary,
484            frames,
485            frame_metas,
486        })
487    }
488
    /// Creates a ZstdArray from a primitive array.
    ///
    /// Trains and uses a shared zstd dictionary when there are enough frames;
    /// see [`Self::from_primitive_without_dict`] to disable that.
    ///
    /// # Arguments
    /// * `parray` - The primitive array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_primitive(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, true)
    }
502
    /// Creates a ZstdArray from a primitive array without using a dictionary.
    ///
    /// This is useful when the compressed data will be decompressed by systems
    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
    ///
    /// Note: Without a dictionary, each frame is compressed independently.
    /// Dictionaries are trained from sample data from previously seen frames,
    /// to improve compression ratio.
    ///
    /// # Arguments
    /// * `parray` - The primitive array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_primitive_without_dict(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, false)
    }
523
524    fn from_primitive_impl(
525        parray: &PrimitiveArray,
526        level: i32,
527        values_per_frame: usize,
528        use_dictionary: bool,
529    ) -> VortexResult<Self> {
530        let dtype = parray.dtype().clone();
531        let byte_width = parray.ptype().byte_width();
532
533        // We compress only the valid elements.
534        let values = collect_valid_primitive(parray)?;
535        let n_values = values.len();
536        let values_per_frame = if values_per_frame > 0 {
537            values_per_frame
538        } else {
539            n_values
540        };
541
542        let value_bytes = values.buffer_handle().try_to_host_sync()?;
543        // Align frames to buffer alignment. This is necessary for overaligned buffers.
544        let alignment = *value_bytes.alignment();
545        let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment;
546
547        let frame_byte_starts = (0..n_values * byte_width)
548            .step_by(step_width)
549            .collect::<Vec<_>>();
550        let Frames {
551            dictionary,
552            frames,
553            frame_metas,
554        } = Self::compress_values(
555            &value_bytes,
556            &frame_byte_starts,
557            level,
558            values_per_frame,
559            n_values,
560            use_dictionary,
561        )?;
562
563        let metadata = ZstdMetadata {
564            dictionary_size: dictionary
565                .as_ref()
566                .map_or(0, |dict| dict.len())
567                .try_into()?,
568            frames: frame_metas,
569        };
570
571        Ok(ZstdArray::new(
572            dictionary,
573            frames,
574            dtype,
575            metadata,
576            parray.len(),
577            parray.validity().clone(),
578        ))
579    }
580
    /// Creates a ZstdArray from a VarBinView array.
    ///
    /// Trains and uses a shared zstd dictionary when there are enough frames;
    /// see [`Self::from_var_bin_view_without_dict`] to disable that.
    ///
    /// # Arguments
    /// * `vbv` - The VarBinView array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_var_bin_view(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, true)
    }
594
    /// Creates a ZstdArray from a VarBinView array without using a dictionary.
    ///
    /// This is useful when the compressed data will be decompressed by systems
    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
    ///
    /// Note: Without a dictionary, each frame is compressed independently.
    /// Dictionaries are trained from sample data from previously seen frames,
    /// to improve compression ratio.
    ///
    /// # Arguments
    /// * `vbv` - The VarBinView array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_var_bin_view_without_dict(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, false)
    }
615
616    fn from_var_bin_view_impl(
617        vbv: &VarBinViewArray,
618        level: i32,
619        values_per_frame: usize,
620        use_dictionary: bool,
621    ) -> VortexResult<Self> {
622        // Approach for strings: we prefix each string with its length as a u32.
623        // This is the same as what Parquet does. In some cases it may be better
624        // to separate the binary data and lengths as two separate streams, but
625        // this approach is simpler and can be best in cases when there is
626        // mutual information between strings and their lengths.
627        let dtype = vbv.dtype().clone();
628
629        // We compress only the valid elements.
630        let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
631        let n_values = value_byte_indices.len();
632        let values_per_frame = if values_per_frame > 0 {
633            values_per_frame
634        } else {
635            n_values
636        };
637
638        let frame_byte_starts = (0..n_values)
639            .step_by(values_per_frame)
640            .map(|i| value_byte_indices[i])
641            .collect::<Vec<_>>();
642        let Frames {
643            dictionary,
644            frames,
645            frame_metas,
646        } = Self::compress_values(
647            &value_bytes,
648            &frame_byte_starts,
649            level,
650            values_per_frame,
651            n_values,
652            use_dictionary,
653        )?;
654
655        let metadata = ZstdMetadata {
656            dictionary_size: dictionary
657                .as_ref()
658                .map_or(0, |dict| dict.len())
659                .try_into()?,
660            frames: frame_metas,
661        };
662        Ok(ZstdArray::new(
663            dictionary,
664            frames,
665            dtype,
666            metadata,
667            vbv.len(),
668            vbv.validity().clone(),
669        ))
670    }
671
672    pub fn from_canonical(
673        canonical: &Canonical,
674        level: i32,
675        values_per_frame: usize,
676    ) -> VortexResult<Option<Self>> {
677        match canonical {
678            Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
679                parray,
680                level,
681                values_per_frame,
682            )?)),
683            Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
684                vbv,
685                level,
686                values_per_frame,
687            )?)),
688            _ => Ok(None),
689        }
690    }
691
692    pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
693        Self::from_canonical(&array.to_canonical()?, level, values_per_frame)?
694            .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
695    }
696
697    fn byte_width(&self) -> usize {
698        if self.dtype.is_primitive() {
699            self.dtype.as_ptype().byte_width()
700        } else {
701            1
702        }
703    }
704
705    pub fn decompress(&self) -> VortexResult<ArrayRef> {
706        // To start, we figure out which frames we need to decompress, and with
707        // what row offset into the first such frame.
708        let byte_width = self.byte_width();
709        let slice_n_rows = self.slice_stop - self.slice_start;
710        let slice_value_indices = self
711            .unsliced_validity
712            .to_mask(self.unsliced_n_rows)
713            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
714
715        let slice_value_idx_start = slice_value_indices[0];
716        let slice_value_idx_stop = slice_value_indices[1];
717
718        let mut frames_to_decompress = vec![];
719        let mut value_idx_start = 0;
720        let mut uncompressed_size_to_decompress = 0;
721        let mut n_skipped_values = 0;
722        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
723            if value_idx_start >= slice_value_idx_stop {
724                break;
725            }
726
727            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
728                .vortex_expect("Uncompressed size must fit in usize");
729            let frame_n_values = if frame_meta.n_values == 0 {
730                // possibly older primitive-only metadata that just didn't store this
731                frame_uncompressed_size / byte_width
732            } else {
733                usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
734            };
735
736            let value_idx_stop = value_idx_start + frame_n_values;
737            if value_idx_stop > slice_value_idx_start {
738                // we need this frame
739                frames_to_decompress.push(frame);
740                uncompressed_size_to_decompress += frame_uncompressed_size;
741            } else {
742                n_skipped_values += frame_n_values;
743            }
744            value_idx_start = value_idx_stop;
745        }
746
747        // then we actually decompress those frames
748        let mut decompressor = if let Some(dictionary) = &self.dictionary {
749            zstd::bulk::Decompressor::with_dictionary(dictionary)?
750        } else {
751            zstd::bulk::Decompressor::new()?
752        };
753        let mut decompressed = ByteBufferMut::with_capacity_aligned(
754            uncompressed_size_to_decompress,
755            Alignment::new(byte_width),
756        );
757        unsafe {
758            // safety: we immediately fill all bytes in the following loop,
759            // assuming our metadata's uncompressed size is correct
760            decompressed.set_len(uncompressed_size_to_decompress);
761        }
762        let mut uncompressed_start = 0;
763        for frame in frames_to_decompress {
764            let uncompressed_written = decompressor
765                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])?;
766            uncompressed_start += uncompressed_written;
767        }
768        if uncompressed_start != uncompressed_size_to_decompress {
769            vortex_panic!(
770                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
771                uncompressed_size_to_decompress,
772                uncompressed_start
773            );
774        }
775
776        let decompressed = decompressed.freeze();
777        // Last, we slice the exact values requested out of the decompressed data.
778        let mut slice_validity = self
779            .unsliced_validity
780            .slice(self.slice_start..self.slice_stop)?;
781
782        // NOTE: this block handles setting the output type when the validity and DType disagree.
783        //
784        // ZSTD is a compact block compressor, meaning that null values are not stored inline in
785        // the data frames. A ZSTD Array that was initialized must always hold onto its full
786        // validity bitmap, even if sliced to only include non-null values.
787        //
788        // We ensure that the validity of the decompressed array ALWAYS matches the validity
789        // implied by the DType.
790        if !self.dtype().is_nullable() && slice_validity != Validity::NonNullable {
791            assert!(
792                slice_validity.all_valid(slice_n_rows)?,
793                "ZSTD array expects to be non-nullable but there are nulls after decompression"
794            );
795
796            slice_validity = Validity::NonNullable;
797        } else if self.dtype.is_nullable() && slice_validity == Validity::NonNullable {
798            slice_validity = Validity::AllValid;
799        }
800        //
801        // END OF IMPORTANT BLOCK
802        //
803
804        match &self.dtype {
805            DType::Primitive(..) => {
806                let slice_values_buffer = decompressed.slice(
807                    (slice_value_idx_start - n_skipped_values) * byte_width
808                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
809                );
810                let primitive = PrimitiveArray::from_values_byte_buffer(
811                    slice_values_buffer,
812                    self.dtype.as_ptype(),
813                    slice_validity,
814                    slice_n_rows,
815                );
816
817                Ok(primitive.into_array())
818            }
819            DType::Binary(_) | DType::Utf8(_) => {
820                match slice_validity.to_mask(slice_n_rows).indices() {
821                    AllOr::All => {
822                        // the decompressed buffer is a bunch of interleaved u32 lengths
823                        // and strings of those lengths, we need to reconstruct the
824                        // views into those strings by passing through the buffer.
825                        let valid_views = reconstruct_views(&decompressed).slice(
826                            slice_value_idx_start - n_skipped_values
827                                ..slice_value_idx_stop - n_skipped_values,
828                        );
829
830                        // SAFETY: we properly construct the views inside `reconstruct_views`
831                        Ok(unsafe {
832                            VarBinViewArray::new_unchecked(
833                                valid_views,
834                                Arc::from([decompressed]),
835                                self.dtype.clone(),
836                                slice_validity,
837                            )
838                        }
839                        .into_array())
840                    }
841                    AllOr::None => Ok(ConstantArray::new(
842                        Scalar::null(self.dtype.clone()),
843                        slice_n_rows,
844                    )
845                    .into_array()),
846                    AllOr::Some(valid_indices) => {
847                        // the decompressed buffer is a bunch of interleaved u32 lengths
848                        // and strings of those lengths, we need to reconstruct the
849                        // views into those strings by passing through the buffer.
850                        let valid_views = reconstruct_views(&decompressed).slice(
851                            slice_value_idx_start - n_skipped_values
852                                ..slice_value_idx_stop - n_skipped_values,
853                        );
854
855                        let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
856                        for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
857                            views[*index] = view
858                        }
859
860                        // SAFETY: we properly construct the views inside `reconstruct_views`
861                        Ok(unsafe {
862                            VarBinViewArray::new_unchecked(
863                                views.freeze(),
864                                Arc::from([decompressed]),
865                                self.dtype.clone(),
866                                slice_validity,
867                            )
868                        }
869                        .into_array())
870                    }
871                }
872            }
873            _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype),
874        }
875    }
876
877    pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
878        let new_start = self.slice_start + start;
879        let new_stop = self.slice_start + stop;
880
881        assert!(
882            new_start <= self.slice_stop,
883            "new slice start {new_start} exceeds end {}",
884            self.slice_stop
885        );
886
887        assert!(
888            new_stop <= self.slice_stop,
889            "new slice stop {new_stop} exceeds end {}",
890            self.slice_stop
891        );
892
893        ZstdArray {
894            slice_start: self.slice_start + start,
895            slice_stop: self.slice_start + stop,
896            stats_set: Default::default(),
897            ..self.clone()
898        }
899    }
900
    /// Consumes the array and returns its parts.
    ///
    /// Note that `validity` and `n_rows` describe the *unsliced* underlying
    /// data; the current view is carried separately via `slice_start` and
    /// `slice_stop`.
    pub fn into_parts(self) -> ZstdArrayParts {
        ZstdArrayParts {
            dictionary: self.dictionary,
            frames: self.frames,
            metadata: self.metadata,
            dtype: self.dtype,
            validity: self.unsliced_validity,
            n_rows: self.unsliced_n_rows,
            slice_start: self.slice_start,
            slice_stop: self.slice_stop,
        }
    }
914
    /// The logical dtype of the values this array decompresses to.
    pub(crate) fn dtype(&self) -> &DType {
        &self.dtype
    }

    /// Start (inclusive) of the current slice window, as a row offset into the
    /// unsliced data.
    pub(crate) fn slice_start(&self) -> usize {
        self.slice_start
    }

    /// Stop (exclusive) of the current slice window, as a row offset into the
    /// unsliced data.
    pub(crate) fn slice_stop(&self) -> usize {
        self.slice_stop
    }

    /// Total row count of the underlying (unsliced) compressed data.
    pub(crate) fn unsliced_n_rows(&self) -> usize {
        self.unsliced_n_rows
    }
930}
931
932impl ValiditySliceHelper for ZstdArray {
933    fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
934        (&self.unsliced_validity, self.slice_start, self.slice_stop)
935    }
936}
937
938impl OperationsVTable<ZstdVTable> for ZstdVTable {
939    fn scalar_at(array: &ZstdArray, index: usize) -> VortexResult<Scalar> {
940        array._slice(index, index + 1).decompress()?.scalar_at(0)
941    }
942}