Skip to main content

vortex_zstd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use itertools::Itertools as _;
9use prost::Message as _;
10use vortex_array::Array;
11use vortex_array::ArrayEq;
12use vortex_array::ArrayHash;
13use vortex_array::ArrayRef;
14use vortex_array::Canonical;
15use vortex_array::ExecutionCtx;
16use vortex_array::IntoArray;
17use vortex_array::Precision;
18use vortex_array::ProstMetadata;
19use vortex_array::ToCanonical;
20use vortex_array::accessor::ArrayAccessor;
21use vortex_array::arrays::ConstantArray;
22use vortex_array::arrays::PrimitiveArray;
23use vortex_array::arrays::VarBinViewArray;
24use vortex_array::arrays::build_views::BinaryView;
25use vortex_array::buffer::BufferHandle;
26use vortex_array::dtype::DType;
27use vortex_array::scalar::Scalar;
28use vortex_array::serde::ArrayChildren;
29use vortex_array::stats::ArrayStats;
30use vortex_array::stats::StatsSetRef;
31use vortex_array::validity::Validity;
32use vortex_array::vtable;
33use vortex_array::vtable::ArrayId;
34use vortex_array::vtable::OperationsVTable;
35use vortex_array::vtable::VTable;
36use vortex_array::vtable::ValidityHelper;
37use vortex_array::vtable::ValiditySliceHelper;
38use vortex_array::vtable::ValidityVTableFromValiditySliceHelper;
39use vortex_array::vtable::validity_nchildren;
40use vortex_array::vtable::validity_to_child;
41use vortex_buffer::Alignment;
42use vortex_buffer::Buffer;
43use vortex_buffer::BufferMut;
44use vortex_buffer::ByteBuffer;
45use vortex_buffer::ByteBufferMut;
46use vortex_error::VortexError;
47use vortex_error::VortexExpect;
48use vortex_error::VortexResult;
49use vortex_error::vortex_bail;
50use vortex_error::vortex_ensure;
51use vortex_error::vortex_err;
52use vortex_error::vortex_panic;
53use vortex_mask::AllOr;
54use vortex_session::VortexSession;
55
56use crate::ZstdFrameMetadata;
57use crate::ZstdMetadata;
58
59// Zstd doesn't support training dictionaries on very few samples.
60const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
61type ViewLen = u32;
62
63// Overall approach here:
64// Zstd can be used on the whole array (values_per_frame = 0), resulting in a single Zstd
65// frame, or it can be used with a dictionary (values_per_frame < # values), resulting in
66// multiple Zstd frames sharing a common dictionary. This latter case is helpful if you
67// want somewhat faster access to slices or individual rows, allowing us to only
68// decompress the necessary frames.
69
70// Visually, during decompression, we have an interval of frames we're
71// decompressing and a tighter interval of the slice we actually care about.
72// |=============values (all valid elements)==============|
73// |<-skipped_uncompressed->|----decompressed-------------|
74//                              |------slice-------|
75//                              ^                  ^
76// |<-slice_uncompressed_start->|                  |
77// |<------------slice_uncompressed_stop---------->|
78// We then insert these values to the correct position using a primitive array
79// constructor.
80
81vtable!(Zstd);
82
83impl VTable for ZstdVTable {
84    type Array = ZstdArray;
85
86    type Metadata = ProstMetadata<ZstdMetadata>;
87    type OperationsVTable = Self;
88    type ValidityVTable = ValidityVTableFromValiditySliceHelper;
89
90    fn id(_array: &Self::Array) -> ArrayId {
91        Self::ID
92    }
93
94    fn len(array: &ZstdArray) -> usize {
95        array.slice_stop - array.slice_start
96    }
97
98    fn dtype(array: &ZstdArray) -> &DType {
99        &array.dtype
100    }
101
102    fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
103        array.stats_set.to_ref(array.as_ref())
104    }
105
106    fn array_hash<H: std::hash::Hasher>(array: &ZstdArray, state: &mut H, precision: Precision) {
107        match &array.dictionary {
108            Some(dict) => {
109                true.hash(state);
110                dict.array_hash(state, precision);
111            }
112            None => {
113                false.hash(state);
114            }
115        }
116        for frame in &array.frames {
117            frame.array_hash(state, precision);
118        }
119        array.dtype.hash(state);
120        array.unsliced_validity.array_hash(state, precision);
121        array.unsliced_n_rows.hash(state);
122        array.slice_start.hash(state);
123        array.slice_stop.hash(state);
124    }
125
126    fn array_eq(array: &ZstdArray, other: &ZstdArray, precision: Precision) -> bool {
127        if !match (&array.dictionary, &other.dictionary) {
128            (Some(d1), Some(d2)) => d1.array_eq(d2, precision),
129            (None, None) => true,
130            _ => false,
131        } {
132            return false;
133        }
134        if array.frames.len() != other.frames.len() {
135            return false;
136        }
137        for (a, b) in array.frames.iter().zip(&other.frames) {
138            if !a.array_eq(b, precision) {
139                return false;
140            }
141        }
142        array.dtype == other.dtype
143            && array
144                .unsliced_validity
145                .array_eq(&other.unsliced_validity, precision)
146            && array.unsliced_n_rows == other.unsliced_n_rows
147            && array.slice_start == other.slice_start
148            && array.slice_stop == other.slice_stop
149    }
150
151    fn nbuffers(array: &ZstdArray) -> usize {
152        array.dictionary.is_some() as usize + array.frames.len()
153    }
154
155    fn buffer(array: &ZstdArray, idx: usize) -> BufferHandle {
156        if let Some(dict) = &array.dictionary {
157            if idx == 0 {
158                return BufferHandle::new_host(dict.clone());
159            }
160            BufferHandle::new_host(array.frames[idx - 1].clone())
161        } else {
162            BufferHandle::new_host(array.frames[idx].clone())
163        }
164    }
165
166    fn buffer_name(array: &ZstdArray, idx: usize) -> Option<String> {
167        if array.dictionary.is_some() {
168            if idx == 0 {
169                Some("dictionary".to_string())
170            } else {
171                Some(format!("frame_{}", idx - 1))
172            }
173        } else {
174            Some(format!("frame_{idx}"))
175        }
176    }
177
178    fn nchildren(array: &ZstdArray) -> usize {
179        validity_nchildren(&array.unsliced_validity)
180    }
181
182    fn child(array: &ZstdArray, idx: usize) -> ArrayRef {
183        validity_to_child(&array.unsliced_validity, array.unsliced_n_rows)
184            .unwrap_or_else(|| vortex_panic!("ZstdArray child index {idx} out of bounds"))
185    }
186
187    fn child_name(_array: &ZstdArray, idx: usize) -> String {
188        match idx {
189            0 => "validity".to_string(),
190            _ => vortex_panic!("ZstdArray child_name index {idx} out of bounds"),
191        }
192    }
193
194    fn metadata(array: &ZstdArray) -> VortexResult<Self::Metadata> {
195        Ok(ProstMetadata(array.metadata.clone()))
196    }
197
198    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
199        Ok(Some(metadata.0.encode_to_vec()))
200    }
201
202    fn deserialize(
203        bytes: &[u8],
204        _dtype: &DType,
205        _len: usize,
206        _buffers: &[BufferHandle],
207        _session: &VortexSession,
208    ) -> VortexResult<Self::Metadata> {
209        Ok(ProstMetadata(ZstdMetadata::decode(bytes)?))
210    }
211
212    fn build(
213        dtype: &DType,
214        len: usize,
215        metadata: &Self::Metadata,
216        buffers: &[BufferHandle],
217        children: &dyn ArrayChildren,
218    ) -> VortexResult<ZstdArray> {
219        let validity = if children.is_empty() {
220            Validity::from(dtype.nullability())
221        } else if children.len() == 1 {
222            let validity = children.get(0, &Validity::DTYPE, len)?;
223            Validity::Array(validity)
224        } else {
225            vortex_bail!("ZstdArray expected 0 or 1 child, got {}", children.len());
226        };
227
228        let (dictionary_buffer, compressed_buffers) = if metadata.0.dictionary_size == 0 {
229            // no dictionary
230            (
231                None,
232                buffers
233                    .iter()
234                    .map(|b| b.clone().try_to_host_sync())
235                    .collect::<VortexResult<Vec<_>>>()?,
236            )
237        } else {
238            // with dictionary
239            (
240                Some(buffers[0].clone().try_to_host_sync()?),
241                buffers[1..]
242                    .iter()
243                    .map(|b| b.clone().try_to_host_sync())
244                    .collect::<VortexResult<Vec<_>>>()?,
245            )
246        };
247
248        Ok(ZstdArray::new(
249            dictionary_buffer,
250            compressed_buffers,
251            dtype.clone(),
252            metadata.0.clone(),
253            len,
254            validity,
255        ))
256    }
257
258    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
259        vortex_ensure!(
260            children.len() <= 1,
261            "ZstdArray expects at most 1 child (validity), got {}",
262            children.len()
263        );
264
265        array.unsliced_validity = if children.is_empty() {
266            Validity::from(array.dtype.nullability())
267        } else {
268            Validity::Array(children.into_iter().next().vortex_expect("checked"))
269        };
270
271        Ok(())
272    }
273
274    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
275        array.decompress()?.execute(ctx)
276    }
277
278    fn reduce_parent(
279        array: &Self::Array,
280        parent: &ArrayRef,
281        child_idx: usize,
282    ) -> VortexResult<Option<ArrayRef>> {
283        crate::rules::RULES.evaluate(array, parent, child_idx)
284    }
285}
286
/// Marker type carrying the vtable implementation for the zstd encoding.
#[derive(Debug)]
pub struct ZstdVTable;

impl ZstdVTable {
    /// Stable encoding identifier used for (de)serialization and registry lookup.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.zstd");
}
293
/// A zstd-compressed array: compressed frames (optionally sharing a trained
/// dictionary) holding only the valid elements, plus the validity and slice
/// bounds needed to reconstruct the logical array.
#[derive(Clone, Debug)]
pub struct ZstdArray {
    // Trained zstd dictionary shared by all frames, if dictionary mode was used.
    pub(crate) dictionary: Option<ByteBuffer>,
    // Compressed zstd frames; together they cover all valid values.
    pub(crate) frames: Vec<ByteBuffer>,
    // Per-frame uncompressed sizes / value counts plus the dictionary size.
    pub(crate) metadata: ZstdMetadata,
    // Logical element type of the uncompressed data.
    dtype: DType,
    // Validity of the *unsliced* array; frames store valid values only.
    pub(crate) unsliced_validity: Validity,
    // Row count of the unsliced array.
    unsliced_n_rows: usize,
    // Lazily computed statistics.
    stats_set: ArrayStats,
    // Current slice window, as row indices into the unsliced array:
    // start is inclusive, stop is exclusive.
    slice_start: usize,
    slice_stop: usize,
}
306
/// The parts of a [`ZstdArray`] returned by [`ZstdArray::into_parts`].
#[derive(Debug)]
pub struct ZstdArrayParts {
    /// The optional dictionary used for compression.
    pub dictionary: Option<ByteBuffer>,
    /// The compressed frames.
    pub frames: Vec<ByteBuffer>,
    /// The compression metadata.
    pub metadata: ZstdMetadata,
    /// The data type of the uncompressed array.
    pub dtype: DType,
    /// The validity of the uncompressed array.
    pub validity: Validity,
    /// The number of rows in the uncompressed array.
    pub n_rows: usize,
    /// Slice start offset (inclusive), in rows of the unsliced array.
    pub slice_start: usize,
    /// Slice stop offset (exclusive), in rows of the unsliced array.
    pub slice_stop: usize,
}
327
/// Result of compressing a value buffer: the frames, their per-frame metadata,
/// and the trained dictionary (if one was used).
struct Frames {
    // Trained dictionary shared by all frames, if dictionary mode was used.
    dictionary: Option<ByteBuffer>,
    // One compressed zstd frame per entry in `frame_metas`.
    frames: Vec<ByteBuffer>,
    // Uncompressed size and value count for each frame.
    frame_metas: Vec<ZstdFrameMetadata>,
}
333
/// Picks an upper bound for the trained zstd dictionary size.
///
/// Follows the recommendation in
/// https://github.com/facebook/zstd/blob/v1.5.5/lib/zdict.h#L190 —
/// roughly 1/100 of the uncompressed data size, capped at 100 KiB, and never
/// below 256 bytes because zstd cannot train dictionaries smaller than that.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    let target = uncompressed_size / 100;
    target.max(256).min(100 * 1024)
}
341
342fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
343    let mask = parray.validity_mask()?;
344    Ok(parray.to_array().filter(mask)?.to_primitive())
345}
346
/// Packs all valid (non-null) values of `vbv` into a single buffer, each value
/// prefixed by its byte length as a little-endian `ViewLen` (u32).
///
/// Returns the packed buffer and, per valid value, the byte offset of its
/// length prefix within that buffer. An all-null input yields an empty buffer
/// and an empty index list.
fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
    let mask = vbv.validity_mask()?;
    let buffer_and_value_byte_indices = match mask.bit_buffer() {
        AllOr::None => (Buffer::empty(), Vec::new()),
        _ => {
            // Reserve room for all value bytes plus one length prefix per
            // valid value, so the loop below never reallocates.
            let mut buffer = BufferMut::with_capacity(
                usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
                    + mask.true_count() * size_of::<ViewLen>(),
            );
            let mut value_byte_indices = Vec::new();
            vbv.with_iterator(|iterator| {
                // by flattening, we should omit nulls
                for value in iterator.flatten() {
                    value_byte_indices.push(buffer.len());
                    // here's where we write the string lengths
                    buffer
                        .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
                    buffer.extend_from_slice(value);
                }
                Ok::<_, VortexError>(())
            })?;
            (buffer.freeze(), value_byte_indices)
        }
    };
    Ok(buffer_and_value_byte_indices)
}
373
/// Reconstruct BinaryView structs from length-prefixed byte data.
///
/// The buffer contains interleaved u32 lengths (little-endian) and string data,
/// as produced by `collect_valid_vbv`. Each resulting view references buffer
/// index 0 at the offset just past its length prefix.
pub fn reconstruct_views(buffer: &ByteBuffer) -> Buffer<BinaryView> {
    let mut res = BufferMut::<BinaryView>::empty();
    let mut offset = 0;
    while offset < buffer.len() {
        // Read the little-endian length prefix for the next value.
        let str_len = ViewLen::from_le_bytes(
            buffer
                .get(offset..offset + size_of::<ViewLen>())
                .vortex_expect("corrupted zstd length")
                .try_into()
                .ok()
                .vortex_expect("must fit ViewLen size"),
        ) as usize;
        offset += size_of::<ViewLen>();
        // `offset` now points at the value bytes themselves; the view records
        // this data offset (into buffer 0).
        let value = &buffer[offset..offset + str_len];
        res.push(BinaryView::make_view(
            value,
            0,
            u32::try_from(offset).vortex_expect("offset must fit in u32"),
        ));
        offset += str_len;
    }
    res.freeze()
}
400
401impl ZstdArray {
402    pub fn new(
403        dictionary: Option<ByteBuffer>,
404        frames: Vec<ByteBuffer>,
405        dtype: DType,
406        metadata: ZstdMetadata,
407        n_rows: usize,
408        validity: Validity,
409    ) -> Self {
410        Self {
411            dictionary,
412            frames,
413            metadata,
414            dtype,
415            unsliced_validity: validity,
416            unsliced_n_rows: n_rows,
417            stats_set: Default::default(),
418            slice_start: 0,
419            slice_stop: n_rows,
420        }
421    }
422
    /// Compresses `value_bytes` into one zstd frame per entry of
    /// `frame_byte_starts`, optionally training a shared dictionary first.
    ///
    /// `frame_byte_starts[i]` is the byte offset where frame `i` begins; it
    /// ends at `frame_byte_starts[i + 1]`, or at the end of the buffer for the
    /// last frame. Dictionary training is skipped when disabled or when there
    /// are fewer than `MIN_SAMPLES_FOR_DICTIONARY` frames to sample from.
    fn compress_values(
        value_bytes: &ByteBuffer,
        frame_byte_starts: &[usize],
        level: i32,
        values_per_frame: usize,
        n_values: usize,
        use_dictionary: bool,
    ) -> VortexResult<Frames> {
        let n_frames = frame_byte_starts.len();

        // Would-be sample sizes if we end up applying zstd dictionary
        let mut sample_sizes = Vec::with_capacity(n_frames);
        for i in 0..n_frames {
            let frame_byte_end = frame_byte_starts
                .get(i + 1)
                .copied()
                .unwrap_or(value_bytes.len());
            sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
        }
        // The frames must exactly tile the value buffer.
        debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());

        let (dictionary, mut compressor) = if !use_dictionary
            || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY
        {
            // no dictionary
            (None, zstd::bulk::Compressor::new(level)?)
        } else {
            // with dictionary
            let max_dict_size = choose_max_dict_size(value_bytes.len());
            let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
                .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;

            let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
            (Some(ByteBuffer::from(dict)), compressor)
        };

        let mut frame_metas = vec![];
        let mut frames = vec![];
        for i in 0..n_frames {
            let frame_byte_end = frame_byte_starts
                .get(i + 1)
                .copied()
                .unwrap_or(value_bytes.len());

            let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
            let compressed = compressor
                .compress(uncompressed)
                .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
            frame_metas.push(ZstdFrameMetadata {
                uncompressed_size: uncompressed.len() as u64,
                // NOTE(review): this assumes frame i starts at value index
                // `i * values_per_frame`. Confirm that holds when the caller
                // rounds frame boundaries up to the buffer alignment
                // (see `from_primitive_impl`'s `step_width`).
                n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
            });
            frames.push(ByteBuffer::from(compressed));
        }

        Ok(Frames {
            dictionary,
            frames,
            frame_metas,
        })
    }
484
    /// Creates a ZstdArray from a primitive array.
    ///
    /// Dictionary training is enabled; see
    /// [`ZstdArray::from_primitive_without_dict`] to disable it.
    ///
    /// # Arguments
    /// * `parray` - The primitive array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_primitive(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, true)
    }
498
    /// Creates a ZstdArray from a primitive array without using a dictionary.
    ///
    /// This is useful when the compressed data will be decompressed by systems
    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
    ///
    /// Note: Without a dictionary, each frame is compressed independently.
    /// Dictionaries are trained from sample data from previously seen frames,
    /// to improve compression ratio.
    ///
    /// # Arguments
    /// * `parray` - The primitive array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    ///
    /// # Errors
    /// Propagates any error from the underlying zstd compression.
    pub fn from_primitive_without_dict(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, false)
    }
519
520    fn from_primitive_impl(
521        parray: &PrimitiveArray,
522        level: i32,
523        values_per_frame: usize,
524        use_dictionary: bool,
525    ) -> VortexResult<Self> {
526        let dtype = parray.dtype().clone();
527        let byte_width = parray.ptype().byte_width();
528
529        // We compress only the valid elements.
530        let values = collect_valid_primitive(parray)?;
531        let n_values = values.len();
532        let values_per_frame = if values_per_frame > 0 {
533            values_per_frame
534        } else {
535            n_values
536        };
537
538        let value_bytes = values.buffer_handle().try_to_host_sync()?;
539        // Align frames to buffer alignment. This is necessary for overaligned buffers.
540        let alignment = *value_bytes.alignment();
541        let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment;
542
543        let frame_byte_starts = (0..n_values * byte_width)
544            .step_by(step_width)
545            .collect::<Vec<_>>();
546        let Frames {
547            dictionary,
548            frames,
549            frame_metas,
550        } = Self::compress_values(
551            &value_bytes,
552            &frame_byte_starts,
553            level,
554            values_per_frame,
555            n_values,
556            use_dictionary,
557        )?;
558
559        let metadata = ZstdMetadata {
560            dictionary_size: dictionary
561                .as_ref()
562                .map_or(0, |dict| dict.len())
563                .try_into()?,
564            frames: frame_metas,
565        };
566
567        Ok(ZstdArray::new(
568            dictionary,
569            frames,
570            dtype,
571            metadata,
572            parray.len(),
573            parray.validity().clone(),
574        ))
575    }
576
    /// Creates a ZstdArray from a VarBinView array.
    ///
    /// Dictionary training is enabled; see
    /// [`ZstdArray::from_var_bin_view_without_dict`] to disable it.
    ///
    /// # Arguments
    /// * `vbv` - The VarBinView array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_var_bin_view(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, true)
    }
590
    /// Creates a ZstdArray from a VarBinView array without using a dictionary.
    ///
    /// This is useful when the compressed data will be decompressed by systems
    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
    ///
    /// Note: Without a dictionary, each frame is compressed independently.
    /// Dictionaries are trained from sample data from previously seen frames,
    /// to improve compression ratio.
    ///
    /// # Arguments
    /// * `vbv` - The VarBinView array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    ///
    /// # Errors
    /// Propagates any error from the underlying zstd compression.
    pub fn from_var_bin_view_without_dict(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, false)
    }
611
612    fn from_var_bin_view_impl(
613        vbv: &VarBinViewArray,
614        level: i32,
615        values_per_frame: usize,
616        use_dictionary: bool,
617    ) -> VortexResult<Self> {
618        // Approach for strings: we prefix each string with its length as a u32.
619        // This is the same as what Parquet does. In some cases it may be better
620        // to separate the binary data and lengths as two separate streams, but
621        // this approach is simpler and can be best in cases when there is
622        // mutual information between strings and their lengths.
623        let dtype = vbv.dtype().clone();
624
625        // We compress only the valid elements.
626        let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
627        let n_values = value_byte_indices.len();
628        let values_per_frame = if values_per_frame > 0 {
629            values_per_frame
630        } else {
631            n_values
632        };
633
634        let frame_byte_starts = (0..n_values)
635            .step_by(values_per_frame)
636            .map(|i| value_byte_indices[i])
637            .collect::<Vec<_>>();
638        let Frames {
639            dictionary,
640            frames,
641            frame_metas,
642        } = Self::compress_values(
643            &value_bytes,
644            &frame_byte_starts,
645            level,
646            values_per_frame,
647            n_values,
648            use_dictionary,
649        )?;
650
651        let metadata = ZstdMetadata {
652            dictionary_size: dictionary
653                .as_ref()
654                .map_or(0, |dict| dict.len())
655                .try_into()?,
656            frames: frame_metas,
657        };
658        Ok(ZstdArray::new(
659            dictionary,
660            frames,
661            dtype,
662            metadata,
663            vbv.len(),
664            vbv.validity().clone(),
665        ))
666    }
667
668    pub fn from_canonical(
669        canonical: &Canonical,
670        level: i32,
671        values_per_frame: usize,
672    ) -> VortexResult<Option<Self>> {
673        match canonical {
674            Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
675                parray,
676                level,
677                values_per_frame,
678            )?)),
679            Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
680                vbv,
681                level,
682                values_per_frame,
683            )?)),
684            _ => Ok(None),
685        }
686    }
687
    /// Canonicalizes `array` and compresses it, erroring if its canonical form
    /// is neither primitive nor VarBinView.
    pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
        Self::from_canonical(&array.to_canonical()?, level, values_per_frame)?
            .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
    }
692
693    fn byte_width(&self) -> usize {
694        if self.dtype.is_primitive() {
695            self.dtype.as_ptype().byte_width()
696        } else {
697            1
698        }
699    }
700
    /// Decompresses into a canonical array (primitive or VarBinView),
    /// materializing only the frames that overlap the current slice window.
    ///
    /// # Errors
    /// Returns an error if zstd decompression fails or the validity cannot be
    /// sliced.
    ///
    /// # Panics
    /// Panics if the stored metadata disagrees with the decompressed sizes, or
    /// if the dtype is not primitive/binary/utf8.
    pub fn decompress(&self) -> VortexResult<ArrayRef> {
        // To start, we figure out which frames we need to decompress, and with
        // what row offset into the first such frame.
        let byte_width = self.byte_width();
        let slice_n_rows = self.slice_stop - self.slice_start;
        // Convert the slice's row bounds into *valid-value* indices, since the
        // frames store only valid values.
        let slice_value_indices = self
            .unsliced_validity
            .to_mask(self.unsliced_n_rows)
            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);

        let slice_value_idx_start = slice_value_indices[0];
        let slice_value_idx_stop = slice_value_indices[1];

        // Walk frames in order, keeping those overlapping the requested value
        // range and counting values in the fully-skipped prefix frames.
        let mut frames_to_decompress = vec![];
        let mut value_idx_start = 0;
        let mut uncompressed_size_to_decompress = 0;
        let mut n_skipped_values = 0;
        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
            if value_idx_start >= slice_value_idx_stop {
                break;
            }

            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
                .vortex_expect("Uncompressed size must fit in usize");
            let frame_n_values = if frame_meta.n_values == 0 {
                // possibly older primitive-only metadata that just didn't store this
                frame_uncompressed_size / byte_width
            } else {
                usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
            };

            let value_idx_stop = value_idx_start + frame_n_values;
            if value_idx_stop > slice_value_idx_start {
                // we need this frame
                frames_to_decompress.push(frame);
                uncompressed_size_to_decompress += frame_uncompressed_size;
            } else {
                n_skipped_values += frame_n_values;
            }
            value_idx_start = value_idx_stop;
        }

        // then we actually decompress those frames
        let mut decompressor = if let Some(dictionary) = &self.dictionary {
            zstd::bulk::Decompressor::with_dictionary(dictionary)?
        } else {
            zstd::bulk::Decompressor::new()?
        };
        let mut decompressed = ByteBufferMut::with_capacity_aligned(
            uncompressed_size_to_decompress,
            Alignment::new(byte_width),
        );
        unsafe {
            // safety: we immediately fill all bytes in the following loop,
            // assuming our metadata's uncompressed size is correct
            decompressed.set_len(uncompressed_size_to_decompress);
        }
        // Each frame is written immediately after the previous one.
        let mut uncompressed_start = 0;
        for frame in frames_to_decompress {
            let uncompressed_written = decompressor
                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])?;
            uncompressed_start += uncompressed_written;
        }
        if uncompressed_start != uncompressed_size_to_decompress {
            vortex_panic!(
                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
                uncompressed_size_to_decompress,
                uncompressed_start
            );
        }

        let decompressed = decompressed.freeze();
        // Last, we slice the exact values requested out of the decompressed data.
        let mut slice_validity = self
            .unsliced_validity
            .slice(self.slice_start..self.slice_stop)?;

        // NOTE: this block handles setting the output type when the validity and DType disagree.
        //
        // ZSTD is a compact block compressor, meaning that null values are not stored inline in
        // the data frames. A ZSTD Array that was initialized must always hold onto its full
        // validity bitmap, even if sliced to only include non-null values.
        //
        // We ensure that the validity of the decompressed array ALWAYS matches the validity
        // implied by the DType.
        if !self.dtype().is_nullable() && slice_validity != Validity::NonNullable {
            assert!(
                slice_validity.all_valid(slice_n_rows)?,
                "ZSTD array expects to be non-nullable but there are nulls after decompression"
            );

            slice_validity = Validity::NonNullable;
        } else if self.dtype.is_nullable() && slice_validity == Validity::NonNullable {
            slice_validity = Validity::AllValid;
        }
        //
        // END OF IMPORTANT BLOCK
        //

        match &self.dtype {
            DType::Primitive(..) => {
                // Value indices are relative to all valid values; subtract the
                // values skipped in un-decompressed prefix frames to index into
                // the decompressed bytes.
                let slice_values_buffer = decompressed.slice(
                    (slice_value_idx_start - n_skipped_values) * byte_width
                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
                );
                let primitive = PrimitiveArray::from_values_byte_buffer(
                    slice_values_buffer,
                    self.dtype.as_ptype(),
                    slice_validity,
                    slice_n_rows,
                );

                Ok(primitive.into_array())
            }
            DType::Binary(_) | DType::Utf8(_) => {
                match slice_validity.to_mask(slice_n_rows).indices() {
                    AllOr::All => {
                        // the decompressed buffer is a bunch of interleaved u32 lengths
                        // and strings of those lengths, we need to reconstruct the
                        // views into those strings by passing through the buffer.
                        let valid_views = reconstruct_views(&decompressed).slice(
                            slice_value_idx_start - n_skipped_values
                                ..slice_value_idx_stop - n_skipped_values,
                        );

                        // SAFETY: we properly construct the views inside `reconstruct_views`
                        Ok(unsafe {
                            VarBinViewArray::new_unchecked(
                                valid_views,
                                Arc::from([decompressed]),
                                self.dtype.clone(),
                                slice_validity,
                            )
                        }
                        .into_array())
                    }
                    AllOr::None => Ok(ConstantArray::new(
                        Scalar::null(self.dtype.clone()),
                        slice_n_rows,
                    )
                    .into_array()),
                    AllOr::Some(valid_indices) => {
                        // the decompressed buffer is a bunch of interleaved u32 lengths
                        // and strings of those lengths, we need to reconstruct the
                        // views into those strings by passing through the buffer.
                        let valid_views = reconstruct_views(&decompressed).slice(
                            slice_value_idx_start - n_skipped_values
                                ..slice_value_idx_stop - n_skipped_values,
                        );

                        // Scatter each valid view to its row position; null rows
                        // keep the zeroed (empty) view.
                        let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
                        for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
                            views[*index] = view
                        }

                        // SAFETY: we properly construct the views inside `reconstruct_views`
                        Ok(unsafe {
                            VarBinViewArray::new_unchecked(
                                views.freeze(),
                                Arc::from([decompressed]),
                                self.dtype.clone(),
                                slice_validity,
                            )
                        }
                        .into_array())
                    }
                }
            }
            _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype),
        }
    }
872
873    pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
874        let new_start = self.slice_start + start;
875        let new_stop = self.slice_start + stop;
876
877        assert!(
878            new_start <= self.slice_stop,
879            "new slice start {new_start} exceeds end {}",
880            self.slice_stop
881        );
882
883        assert!(
884            new_stop <= self.slice_stop,
885            "new slice stop {new_stop} exceeds end {}",
886            self.slice_stop
887        );
888
889        ZstdArray {
890            slice_start: self.slice_start + start,
891            slice_stop: self.slice_start + stop,
892            stats_set: Default::default(),
893            ..self.clone()
894        }
895    }
896
897    /// Consumes the array and returns its parts.
898    pub fn into_parts(self) -> ZstdArrayParts {
899        ZstdArrayParts {
900            dictionary: self.dictionary,
901            frames: self.frames,
902            metadata: self.metadata,
903            dtype: self.dtype,
904            validity: self.unsliced_validity,
905            n_rows: self.unsliced_n_rows,
906            slice_start: self.slice_start,
907            slice_stop: self.slice_stop,
908        }
909    }
910
    /// Returns the logical [`DType`] of the array's elements.
    pub(crate) fn dtype(&self) -> &DType {
        &self.dtype
    }
914
    /// Returns the start (inclusive) of the current slice, as an offset into
    /// the unsliced array.
    pub(crate) fn slice_start(&self) -> usize {
        self.slice_start
    }
918
    /// Returns the stop (exclusive) of the current slice, as an offset into
    /// the unsliced array.
    pub(crate) fn slice_stop(&self) -> usize {
        self.slice_stop
    }
922
    /// Returns the total row count of the underlying (unsliced) array.
    pub(crate) fn unsliced_n_rows(&self) -> usize {
        self.unsliced_n_rows
    }
926}
927
impl ValiditySliceHelper for ZstdArray {
    /// Exposes the unsliced validity together with the current slice bounds,
    /// letting the shared validity machinery (see
    /// `ValidityVTableFromValiditySliceHelper`) derive the sliced view.
    fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
        (&self.unsliced_validity, self.slice_start, self.slice_stop)
    }
}
933
934impl OperationsVTable<ZstdVTable> for ZstdVTable {
935    fn scalar_at(array: &ZstdArray, index: usize) -> VortexResult<Scalar> {
936        array._slice(index, index + 1).decompress()?.scalar_at(0)
937    }
938}