// vortex_zstd/array.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use itertools::Itertools as _;
9use prost::Message as _;
10use vortex_array::ArrayBufferVisitor;
11use vortex_array::ArrayChildVisitor;
12use vortex_array::ArrayEq;
13use vortex_array::ArrayHash;
14use vortex_array::ArrayRef;
15use vortex_array::Canonical;
16use vortex_array::ExecutionCtx;
17use vortex_array::IntoArray;
18use vortex_array::Precision;
19use vortex_array::ProstMetadata;
20use vortex_array::ToCanonical;
21use vortex_array::accessor::ArrayAccessor;
22use vortex_array::arrays::ConstantArray;
23use vortex_array::arrays::PrimitiveArray;
24use vortex_array::arrays::VarBinViewArray;
25use vortex_array::arrays::build_views::BinaryView;
26use vortex_array::buffer::BufferHandle;
27use vortex_array::compute::filter;
28use vortex_array::scalar::Scalar;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::stats::ArrayStats;
31use vortex_array::stats::StatsSetRef;
32use vortex_array::validity::Validity;
33use vortex_array::vtable;
34use vortex_array::vtable::ArrayId;
35use vortex_array::vtable::BaseArrayVTable;
36use vortex_array::vtable::OperationsVTable;
37use vortex_array::vtable::VTable;
38use vortex_array::vtable::ValidityHelper;
39use vortex_array::vtable::ValiditySliceHelper;
40use vortex_array::vtable::ValidityVTableFromValiditySliceHelper;
41use vortex_array::vtable::VisitorVTable;
42use vortex_array::vtable::validity_nchildren;
43use vortex_buffer::Alignment;
44use vortex_buffer::Buffer;
45use vortex_buffer::BufferMut;
46use vortex_buffer::ByteBuffer;
47use vortex_buffer::ByteBufferMut;
48use vortex_dtype::DType;
49use vortex_error::VortexError;
50use vortex_error::VortexExpect;
51use vortex_error::VortexResult;
52use vortex_error::vortex_bail;
53use vortex_error::vortex_ensure;
54use vortex_error::vortex_err;
55use vortex_error::vortex_panic;
56use vortex_mask::AllOr;
57use vortex_session::VortexSession;
58
59use crate::ZstdFrameMetadata;
60use crate::ZstdMetadata;
61
// Zstd doesn't support training dictionaries on very few samples.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
// Width of the little-endian length prefix written before each value when
// compressing VarBinView arrays (see `collect_valid_vbv`/`reconstruct_views`).
type ViewLen = u32;
65
66// Overall approach here:
67// Zstd can be used on the whole array (values_per_frame = 0), resulting in a single Zstd
68// frame, or it can be used with a dictionary (values_per_frame < # values), resulting in
69// multiple Zstd frames sharing a common dictionary. This latter case is helpful if you
70// want somewhat faster access to slices or individual rows, allowing us to only
71// decompress the necessary frames.
72
73// Visually, during decompression, we have an interval of frames we're
74// decompressing and a tighter interval of the slice we actually care about.
75// |=============values (all valid elements)==============|
76// |<-skipped_uncompressed->|----decompressed-------------|
77//                              |------slice-------|
78//                              ^                  ^
79// |<-slice_uncompressed_start->|                  |
80// |<------------slice_uncompressed_stop---------->|
81// We then insert these values to the correct position using a primitive array
82// constructor.
83
84vtable!(Zstd);
85
impl VTable for ZstdVTable {
    type Array = ZstdArray;

    type Metadata = ProstMetadata<ZstdMetadata>;

    type ArrayVTable = Self;
    type OperationsVTable = Self;
    // Validity queries are answered from the (unsliced validity, slice bounds)
    // triple exposed by the `ValiditySliceHelper` impl on `ZstdArray`.
    type ValidityVTable = ValidityVTableFromValiditySliceHelper;
    type VisitorVTable = Self;

    fn id(_array: &Self::Array) -> ArrayId {
        Self::ID
    }

    fn metadata(array: &ZstdArray) -> VortexResult<Self::Metadata> {
        Ok(ProstMetadata(array.metadata.clone()))
    }

    // Metadata round-trips as protobuf-encoded `ZstdMetadata` bytes.
    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
        Ok(Some(metadata.0.encode_to_vec()))
    }

    fn deserialize(
        bytes: &[u8],
        _dtype: &DType,
        _len: usize,
        _buffers: &[BufferHandle],
        _session: &VortexSession,
    ) -> VortexResult<Self::Metadata> {
        Ok(ProstMetadata(ZstdMetadata::decode(bytes)?))
    }

    /// Rebuilds a [`ZstdArray`] from its serialized parts.
    ///
    /// The only possible child is the validity array. The buffer layout
    /// mirrors `visit_buffers`: when `dictionary_size != 0` the dictionary is
    /// buffer 0 and the compressed frames follow; otherwise every buffer is a
    /// frame.
    fn build(
        dtype: &DType,
        len: usize,
        metadata: &Self::Metadata,
        buffers: &[BufferHandle],
        children: &dyn ArrayChildren,
    ) -> VortexResult<ZstdArray> {
        let validity = if children.is_empty() {
            Validity::from(dtype.nullability())
        } else if children.len() == 1 {
            let validity = children.get(0, &Validity::DTYPE, len)?;
            Validity::Array(validity)
        } else {
            vortex_bail!("ZstdArray expected 0 or 1 child, got {}", children.len());
        };

        let (dictionary_buffer, compressed_buffers) = if metadata.0.dictionary_size == 0 {
            // no dictionary
            (
                None,
                buffers
                    .iter()
                    .map(|b| b.clone().try_to_host_sync())
                    .collect::<VortexResult<Vec<_>>>()?,
            )
        } else {
            // with dictionary
            (
                Some(buffers[0].clone().try_to_host_sync()?),
                buffers[1..]
                    .iter()
                    .map(|b| b.clone().try_to_host_sync())
                    .collect::<VortexResult<Vec<_>>>()?,
            )
        };

        Ok(ZstdArray::new(
            dictionary_buffer,
            compressed_buffers,
            dtype.clone(),
            metadata.0.clone(),
            len,
            validity,
        ))
    }

    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
        vortex_ensure!(
            children.len() <= 1,
            "ZstdArray expects at most 1 child (validity), got {}",
            children.len()
        );

        // Zero children resets validity to whatever the dtype's nullability
        // implies; one child replaces the validity array.
        array.unsliced_validity = if children.is_empty() {
            Validity::from(array.dtype.nullability())
        } else {
            Validity::Array(children.into_iter().next().vortex_expect("checked"))
        };

        Ok(())
    }

    // Executing a zstd array decompresses it, then executes the result.
    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
        array.decompress()?.execute(ctx)
    }

    fn reduce_parent(
        array: &Self::Array,
        parent: &ArrayRef,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        crate::rules::RULES.evaluate(array, parent, child_idx)
    }
}
192
/// Marker type carrying the [`VTable`] implementation for the Zstd encoding.
#[derive(Debug)]
pub struct ZstdVTable;

impl ZstdVTable {
    /// Stable identifier for this encoding.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.zstd");
}
199
/// An array whose valid values are stored zstd-compressed, as one or more
/// frames optionally sharing a trained dictionary.
#[derive(Clone, Debug)]
pub struct ZstdArray {
    /// Optional shared zstd dictionary used by all frames.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// The compressed frames; each can be decompressed independently.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Dictionary size and per-frame sizes/value counts.
    pub(crate) metadata: ZstdMetadata,
    dtype: DType,
    /// Validity of the full, unsliced array — slicing only moves the bounds below.
    pub(crate) unsliced_validity: Validity,
    // Row count of the full, unsliced array.
    unsliced_n_rows: usize,
    stats_set: ArrayStats,
    // Logical slice bounds, in rows of the unsliced array.
    slice_start: usize,
    slice_stop: usize,
}
212
/// The parts of a [`ZstdArray`] returned by [`ZstdArray::into_parts`].
#[derive(Debug)]
pub struct ZstdArrayParts {
    /// The optional dictionary used for compression.
    pub dictionary: Option<ByteBuffer>,
    /// The compressed frames.
    pub frames: Vec<ByteBuffer>,
    /// The compression metadata.
    pub metadata: ZstdMetadata,
    /// The data type of the uncompressed array.
    pub dtype: DType,
    /// The validity of the uncompressed array.
    pub validity: Validity,
    /// The number of rows in the uncompressed array.
    pub n_rows: usize,
    /// Slice start offset, in rows of the unsliced array.
    pub slice_start: usize,
    /// Slice stop offset (exclusive), in rows of the unsliced array.
    pub slice_stop: usize,
}
233
/// Result of [`ZstdArray::compress_values`]: the compressed frames, their
/// per-frame metadata, and the optional trained dictionary.
struct Frames {
    /// Trained dictionary shared by all frames, if one was used.
    dictionary: Option<ByteBuffer>,
    /// One compressed buffer per frame.
    frames: Vec<ByteBuffer>,
    /// Metadata entry for each frame, parallel to `frames`.
    frame_metas: Vec<ZstdFrameMetadata>,
}
239
/// Picks an upper bound on the trained zstd dictionary size.
///
/// Follows the recommendation from
/// https://github.com/facebook/zstd/blob/v1.5.5/lib/zdict.h#L190:
/// roughly 1/100th of the sample data, capped at 100kB. Zstd also appears
/// unable to train dictionaries smaller than 256 bytes, hence the floor.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const DICT_SIZE_FLOOR: usize = 256;
    const DICT_SIZE_CEILING: usize = 100 * 1024;
    (uncompressed_size / 100).clamp(DICT_SIZE_FLOOR, DICT_SIZE_CEILING)
}
247
248fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
249    let mask = parray.validity_mask()?;
250    Ok(filter(&parray.to_array(), &mask)?.to_primitive())
251}
252
/// Collects the valid (non-null) values of `vbv` into one contiguous buffer of
/// interleaved little-endian `u32` length prefixes and value bytes.
///
/// Returns the buffer together with the byte offset at which each value's
/// length prefix begins (one entry per valid value).
fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
    let mask = vbv.validity_mask()?;
    let buffer_and_value_byte_indices = match mask.bit_buffer() {
        // All elements are null: nothing to compress.
        AllOr::None => (Buffer::empty(), Vec::new()),
        _ => {
            // Reserve room for every value byte plus one length prefix per
            // valid value, so the loop below never reallocates.
            let mut buffer = BufferMut::with_capacity(
                usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
                    + mask.true_count() * size_of::<ViewLen>(),
            );
            let mut value_byte_indices = Vec::new();
            vbv.with_iterator(|iterator| {
                // by flattening, we should omit nulls
                for value in iterator.flatten() {
                    value_byte_indices.push(buffer.len());
                    // here's where we write the string lengths
                    buffer
                        .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
                    buffer.extend_from_slice(value);
                }
                Ok::<_, VortexError>(())
            })?;
            (buffer.freeze(), value_byte_indices)
        }
    };
    Ok(buffer_and_value_byte_indices)
}
279
/// Reconstruct BinaryView structs from length-prefixed byte data.
///
/// The buffer contains interleaved u32 lengths (little-endian) and string data.
/// Each produced view references buffer index 0 at the byte offset where the
/// value's data (not its length prefix) starts, matching how the decompressed
/// buffer is later installed as the VarBinView's only data buffer.
pub fn reconstruct_views(buffer: &ByteBuffer) -> Buffer<BinaryView> {
    let mut res = BufferMut::<BinaryView>::empty();
    let mut offset = 0;
    while offset < buffer.len() {
        // Read the little-endian length prefix for the next value.
        let str_len = ViewLen::from_le_bytes(
            buffer
                .get(offset..offset + size_of::<ViewLen>())
                .vortex_expect("corrupted zstd length")
                .try_into()
                .ok()
                .vortex_expect("must fit ViewLen size"),
        ) as usize;
        offset += size_of::<ViewLen>();
        // `offset` now points at the value bytes themselves.
        let value = &buffer[offset..offset + str_len];
        res.push(BinaryView::make_view(
            value,
            0,
            u32::try_from(offset).vortex_expect("offset must fit in u32"),
        ));
        offset += str_len;
    }
    res.freeze()
}
306
307impl ZstdArray {
    /// Creates a new [`ZstdArray`] from its constituent parts.
    ///
    /// The returned array is unsliced: its logical slice covers all `n_rows`.
    pub fn new(
        dictionary: Option<ByteBuffer>,
        frames: Vec<ByteBuffer>,
        dtype: DType,
        metadata: ZstdMetadata,
        n_rows: usize,
        validity: Validity,
    ) -> Self {
        Self {
            dictionary,
            frames,
            metadata,
            dtype,
            unsliced_validity: validity,
            unsliced_n_rows: n_rows,
            stats_set: Default::default(),
            slice_start: 0,
            slice_stop: n_rows,
        }
    }
328
329    fn compress_values(
330        value_bytes: &ByteBuffer,
331        frame_byte_starts: &[usize],
332        level: i32,
333        values_per_frame: usize,
334        n_values: usize,
335        use_dictionary: bool,
336    ) -> VortexResult<Frames> {
337        let n_frames = frame_byte_starts.len();
338
339        // Would-be sample sizes if we end up applying zstd dictionary
340        let mut sample_sizes = Vec::with_capacity(n_frames);
341        for i in 0..n_frames {
342            let frame_byte_end = frame_byte_starts
343                .get(i + 1)
344                .copied()
345                .unwrap_or(value_bytes.len());
346            sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
347        }
348        debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
349
350        let (dictionary, mut compressor) = if !use_dictionary
351            || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY
352        {
353            // no dictionary
354            (None, zstd::bulk::Compressor::new(level)?)
355        } else {
356            // with dictionary
357            let max_dict_size = choose_max_dict_size(value_bytes.len());
358            let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
359                .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
360
361            let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
362            (Some(ByteBuffer::from(dict)), compressor)
363        };
364
365        let mut frame_metas = vec![];
366        let mut frames = vec![];
367        for i in 0..n_frames {
368            let frame_byte_end = frame_byte_starts
369                .get(i + 1)
370                .copied()
371                .unwrap_or(value_bytes.len());
372
373            let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
374            let compressed = compressor
375                .compress(uncompressed)
376                .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
377            frame_metas.push(ZstdFrameMetadata {
378                uncompressed_size: uncompressed.len() as u64,
379                n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
380            });
381            frames.push(ByteBuffer::from(compressed));
382        }
383
384        Ok(Frames {
385            dictionary,
386            frames,
387            frame_metas,
388        })
389    }
390
    /// Creates a ZstdArray from a primitive array.
    ///
    /// Only the valid (non-null) elements are compressed; validity is carried
    /// alongside the compressed frames.
    ///
    /// # Arguments
    /// * `parray` - The primitive array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_primitive(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, true)
    }
404
    /// Creates a ZstdArray from a primitive array without using a dictionary.
    ///
    /// This is useful when the compressed data will be decompressed by systems
    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
    ///
    /// Note: Without a dictionary, each frame is compressed independently.
    /// Dictionaries are trained from sample data from previously seen frames,
    /// to improve compression ratio.
    ///
    /// # Arguments
    /// * `parray` - The primitive array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_primitive_without_dict(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, false)
    }
425
426    fn from_primitive_impl(
427        parray: &PrimitiveArray,
428        level: i32,
429        values_per_frame: usize,
430        use_dictionary: bool,
431    ) -> VortexResult<Self> {
432        let dtype = parray.dtype().clone();
433        let byte_width = parray.ptype().byte_width();
434
435        // We compress only the valid elements.
436        let values = collect_valid_primitive(parray)?;
437        let n_values = values.len();
438        let values_per_frame = if values_per_frame > 0 {
439            values_per_frame
440        } else {
441            n_values
442        };
443
444        let value_bytes = values.buffer_handle().try_to_host_sync()?;
445        // Align frames to buffer alignment. This is necessary for overaligned buffers.
446        let alignment = *value_bytes.alignment();
447        let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment;
448
449        let frame_byte_starts = (0..n_values * byte_width)
450            .step_by(step_width)
451            .collect::<Vec<_>>();
452        let Frames {
453            dictionary,
454            frames,
455            frame_metas,
456        } = Self::compress_values(
457            &value_bytes,
458            &frame_byte_starts,
459            level,
460            values_per_frame,
461            n_values,
462            use_dictionary,
463        )?;
464
465        let metadata = ZstdMetadata {
466            dictionary_size: dictionary
467                .as_ref()
468                .map_or(0, |dict| dict.len())
469                .try_into()?,
470            frames: frame_metas,
471        };
472
473        Ok(ZstdArray::new(
474            dictionary,
475            frames,
476            dtype,
477            metadata,
478            parray.len(),
479            parray.validity().clone(),
480        ))
481    }
482
    /// Creates a ZstdArray from a VarBinView array.
    ///
    /// Only the valid (non-null) values are compressed, each prefixed by its
    /// u32 length (see `from_var_bin_view_impl`).
    ///
    /// # Arguments
    /// * `vbv` - The VarBinView array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_var_bin_view(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, true)
    }
496
    /// Creates a ZstdArray from a VarBinView array without using a dictionary.
    ///
    /// This is useful when the compressed data will be decompressed by systems
    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
    ///
    /// Note: Without a dictionary, each frame is compressed independently.
    /// Dictionaries are trained from sample data from previously seen frames,
    /// to improve compression ratio.
    ///
    /// # Arguments
    /// * `vbv` - The VarBinView array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_var_bin_view_without_dict(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, false)
    }
517
518    fn from_var_bin_view_impl(
519        vbv: &VarBinViewArray,
520        level: i32,
521        values_per_frame: usize,
522        use_dictionary: bool,
523    ) -> VortexResult<Self> {
524        // Approach for strings: we prefix each string with its length as a u32.
525        // This is the same as what Parquet does. In some cases it may be better
526        // to separate the binary data and lengths as two separate streams, but
527        // this approach is simpler and can be best in cases when there is
528        // mutual information between strings and their lengths.
529        let dtype = vbv.dtype().clone();
530
531        // We compress only the valid elements.
532        let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
533        let n_values = value_byte_indices.len();
534        let values_per_frame = if values_per_frame > 0 {
535            values_per_frame
536        } else {
537            n_values
538        };
539
540        let frame_byte_starts = (0..n_values)
541            .step_by(values_per_frame)
542            .map(|i| value_byte_indices[i])
543            .collect::<Vec<_>>();
544        let Frames {
545            dictionary,
546            frames,
547            frame_metas,
548        } = Self::compress_values(
549            &value_bytes,
550            &frame_byte_starts,
551            level,
552            values_per_frame,
553            n_values,
554            use_dictionary,
555        )?;
556
557        let metadata = ZstdMetadata {
558            dictionary_size: dictionary
559                .as_ref()
560                .map_or(0, |dict| dict.len())
561                .try_into()?,
562            frames: frame_metas,
563        };
564        Ok(ZstdArray::new(
565            dictionary,
566            frames,
567            dtype,
568            metadata,
569            vbv.len(),
570            vbv.validity().clone(),
571        ))
572    }
573
574    pub fn from_canonical(
575        canonical: &Canonical,
576        level: i32,
577        values_per_frame: usize,
578    ) -> VortexResult<Option<Self>> {
579        match canonical {
580            Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
581                parray,
582                level,
583                values_per_frame,
584            )?)),
585            Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
586                vbv,
587                level,
588                values_per_frame,
589            )?)),
590            _ => Ok(None),
591        }
592    }
593
594    pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
595        Self::from_canonical(&array.to_canonical()?, level, values_per_frame)?
596            .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
597    }
598
599    fn byte_width(&self) -> usize {
600        if self.dtype.is_primitive() {
601            self.dtype.as_ptype().byte_width()
602        } else {
603            1
604        }
605    }
606
    /// Decompresses this array — honoring the current slice — into a
    /// canonical primitive, VarBinView, or (all-null) constant array.
    pub fn decompress(&self) -> VortexResult<ArrayRef> {
        // To start, we figure out which frames we need to decompress, and with
        // what row offset into the first such frame.
        let byte_width = self.byte_width();
        let slice_n_rows = self.slice_stop - self.slice_start;
        // Frames store only valid elements, so translate the row-based slice
        // bounds into *valid-value* indices before doing frame arithmetic.
        let slice_value_indices = self
            .unsliced_validity
            .to_mask(self.unsliced_n_rows)
            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);

        let slice_value_idx_start = slice_value_indices[0];
        let slice_value_idx_stop = slice_value_indices[1];

        let mut frames_to_decompress = vec![];
        let mut value_idx_start = 0;
        let mut uncompressed_size_to_decompress = 0;
        let mut n_skipped_values = 0;
        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
            // Frames entirely past the end of the slice are never needed.
            if value_idx_start >= slice_value_idx_stop {
                break;
            }

            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
                .vortex_expect("Uncompressed size must fit in usize");
            let frame_n_values = if frame_meta.n_values == 0 {
                // possibly older primitive-only metadata that just didn't store this
                frame_uncompressed_size / byte_width
            } else {
                usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
            };

            let value_idx_stop = value_idx_start + frame_n_values;
            if value_idx_stop > slice_value_idx_start {
                // we need this frame
                frames_to_decompress.push(frame);
                uncompressed_size_to_decompress += frame_uncompressed_size;
            } else {
                // Frame lies entirely before the slice: skip it, but remember
                // how many values it held so indices can be rebased below.
                n_skipped_values += frame_n_values;
            }
            value_idx_start = value_idx_stop;
        }

        // then we actually decompress those frames
        let mut decompressor = if let Some(dictionary) = &self.dictionary {
            zstd::bulk::Decompressor::with_dictionary(dictionary)
        } else {
            zstd::bulk::Decompressor::new()
        }
        .vortex_expect("Decompressor encountered io error");
        let mut decompressed = ByteBufferMut::with_capacity_aligned(
            uncompressed_size_to_decompress,
            Alignment::new(byte_width),
        );
        unsafe {
            // safety: we immediately fill all bytes in the following loop,
            // assuming our metadata's uncompressed size is correct
            decompressed.set_len(uncompressed_size_to_decompress);
        }
        let mut uncompressed_start = 0;
        for frame in frames_to_decompress {
            let uncompressed_written = decompressor
                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])
                .vortex_expect("error while decompressing zstd array");
            uncompressed_start += uncompressed_written;
        }
        if uncompressed_start != uncompressed_size_to_decompress {
            vortex_panic!(
                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
                uncompressed_size_to_decompress,
                uncompressed_start
            );
        }

        let decompressed = decompressed.freeze();
        // Last, we slice the exact values requested out of the decompressed data.
        let mut slice_validity = self
            .unsliced_validity
            .slice(self.slice_start..self.slice_stop)?;

        // NOTE: this block handles setting the output type when the validity and DType disagree.
        //
        // ZSTD is a compact block compressor, meaning that null values are not stored inline in
        // the data frames. A ZSTD Array that was initialized must always hold onto its full
        // validity bitmap, even if sliced to only include non-null values.
        //
        // We ensure that the validity of the decompressed array ALWAYS matches the validity
        // implied by the DType.
        if !self.dtype().is_nullable() && slice_validity != Validity::NonNullable {
            assert!(
                slice_validity.all_valid(slice_n_rows)?,
                "ZSTD array expects to be non-nullable but there are nulls after decompression"
            );

            slice_validity = Validity::NonNullable;
        } else if self.dtype.is_nullable() && slice_validity == Validity::NonNullable {
            slice_validity = Validity::AllValid;
        }
        //
        // END OF IMPORTANT BLOCK
        //

        match &self.dtype {
            DType::Primitive(..) => {
                // Value indices are rebased by `n_skipped_values` because the
                // decompressed buffer starts at the first frame we actually
                // decompressed, not at the first frame of the array.
                let slice_values_buffer = decompressed.slice(
                    (slice_value_idx_start - n_skipped_values) * byte_width
                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
                );
                let primitive = PrimitiveArray::from_values_byte_buffer(
                    slice_values_buffer,
                    self.dtype.as_ptype(),
                    slice_validity,
                    slice_n_rows,
                );

                Ok(primitive.into_array())
            }
            DType::Binary(_) | DType::Utf8(_) => {
                match slice_validity.to_mask(slice_n_rows).indices() {
                    AllOr::All => {
                        // the decompressed buffer is a bunch of interleaved u32 lengths
                        // and strings of those lengths, we need to reconstruct the
                        // views into those strings by passing through the buffer.
                        let valid_views = reconstruct_views(&decompressed).slice(
                            slice_value_idx_start - n_skipped_values
                                ..slice_value_idx_stop - n_skipped_values,
                        );

                        // SAFETY: we properly construct the views inside `reconstruct_views`
                        Ok(unsafe {
                            VarBinViewArray::new_unchecked(
                                valid_views,
                                Arc::from([decompressed]),
                                self.dtype.clone(),
                                slice_validity,
                            )
                        }
                        .into_array())
                    }
                    // Every row in the slice is null: no frame data needed.
                    AllOr::None => Ok(ConstantArray::new(
                        Scalar::null(self.dtype.clone()),
                        slice_n_rows,
                    )
                    .into_array()),
                    AllOr::Some(valid_indices) => {
                        // the decompressed buffer is a bunch of interleaved u32 lengths
                        // and strings of those lengths, we need to reconstruct the
                        // views into those strings by passing through the buffer.
                        let valid_views = reconstruct_views(&decompressed).slice(
                            slice_value_idx_start - n_skipped_values
                                ..slice_value_idx_stop - n_skipped_values,
                        );

                        // Scatter the valid views into their row positions,
                        // leaving zeroed views at the null rows.
                        let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
                        for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
                            views[*index] = view
                        }

                        // SAFETY: we properly construct the views inside `reconstruct_views`
                        Ok(unsafe {
                            VarBinViewArray::new_unchecked(
                                views.freeze(),
                                Arc::from([decompressed]),
                                self.dtype.clone(),
                                slice_validity,
                            )
                        }
                        .into_array())
                    }
                }
            }
            _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype),
        }
    }
780
781    pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
782        let new_start = self.slice_start + start;
783        let new_stop = self.slice_start + stop;
784
785        assert!(
786            new_start <= self.slice_stop,
787            "new slice start {new_start} exceeds end {}",
788            self.slice_stop
789        );
790
791        assert!(
792            new_stop <= self.slice_stop,
793            "new slice stop {new_stop} exceeds end {}",
794            self.slice_stop
795        );
796
797        ZstdArray {
798            slice_start: self.slice_start + start,
799            slice_stop: self.slice_start + stop,
800            stats_set: Default::default(),
801            ..self.clone()
802        }
803    }
804
    /// Consumes the array and returns its parts.
    ///
    /// The returned validity and row count are those of the *unsliced* array;
    /// the slice bounds are returned separately.
    pub fn into_parts(self) -> ZstdArrayParts {
        ZstdArrayParts {
            dictionary: self.dictionary,
            frames: self.frames,
            metadata: self.metadata,
            dtype: self.dtype,
            validity: self.unsliced_validity,
            n_rows: self.unsliced_n_rows,
            slice_start: self.slice_start,
            slice_stop: self.slice_stop,
        }
    }
818
    /// The logical data type of the (uncompressed) array.
    pub(crate) fn dtype(&self) -> &DType {
        &self.dtype
    }

    /// Start of the logical slice, in rows of the unsliced array.
    pub(crate) fn slice_start(&self) -> usize {
        self.slice_start
    }

    /// End (exclusive) of the logical slice, in rows of the unsliced array.
    pub(crate) fn slice_stop(&self) -> usize {
        self.slice_stop
    }

    /// Row count of the full, unsliced array.
    pub(crate) fn unsliced_n_rows(&self) -> usize {
        self.unsliced_n_rows
    }
834}
835
// Lets the generic `ValidityVTableFromValiditySliceHelper` answer validity
// queries from the unsliced validity plus the current slice bounds.
impl ValiditySliceHelper for ZstdArray {
    fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
        (&self.unsliced_validity, self.slice_start, self.slice_stop)
    }
}
841
impl BaseArrayVTable<ZstdVTable> for ZstdVTable {
    // Length is that of the logical slice, not of the underlying frames.
    fn len(array: &ZstdArray) -> usize {
        array.slice_stop - array.slice_start
    }

    fn dtype(array: &ZstdArray) -> &DType {
        &array.dtype
    }

    fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
        array.stats_set.to_ref(array.as_ref())
    }

    // NOTE: `array_hash` and `array_eq` must stay in lockstep: both fold in
    // the dictionary (with a presence tag), every frame, the dtype, the
    // unsliced validity, the unsliced row count, and the slice bounds.
    fn array_hash<H: std::hash::Hasher>(array: &ZstdArray, state: &mut H, precision: Precision) {
        match &array.dictionary {
            Some(dict) => {
                // Tag presence so `Some(dict)` never hashes like `None`
                // followed by other data.
                true.hash(state);
                dict.array_hash(state, precision);
            }
            None => {
                false.hash(state);
            }
        }
        for frame in &array.frames {
            frame.array_hash(state, precision);
        }
        array.dtype.hash(state);
        array.unsliced_validity.array_hash(state, precision);
        array.unsliced_n_rows.hash(state);
        array.slice_start.hash(state);
        array.slice_stop.hash(state);
    }

    fn array_eq(array: &ZstdArray, other: &ZstdArray, precision: Precision) -> bool {
        // Dictionaries must both be absent, or both present and equal.
        if !match (&array.dictionary, &other.dictionary) {
            (Some(d1), Some(d2)) => d1.array_eq(d2, precision),
            (None, None) => true,
            _ => false,
        } {
            return false;
        }
        if array.frames.len() != other.frames.len() {
            return false;
        }
        for (a, b) in array.frames.iter().zip(&other.frames) {
            if !a.array_eq(b, precision) {
                return false;
            }
        }
        array.dtype == other.dtype
            && array
                .unsliced_validity
                .array_eq(&other.unsliced_validity, precision)
            && array.unsliced_n_rows == other.unsliced_n_rows
            && array.slice_start == other.slice_start
            && array.slice_stop == other.slice_stop
    }
}
900
901impl OperationsVTable<ZstdVTable> for ZstdVTable {
902    fn scalar_at(array: &ZstdArray, index: usize) -> VortexResult<Scalar> {
903        array._slice(index, index + 1).decompress()?.scalar_at(0)
904    }
905}
906
907impl VisitorVTable<ZstdVTable> for ZstdVTable {
908    fn visit_buffers(array: &ZstdArray, visitor: &mut dyn ArrayBufferVisitor) {
909        if let Some(buffer) = &array.dictionary {
910            visitor.visit_buffer_handle("dictionary", &BufferHandle::new_host(buffer.clone()));
911        }
912        for (i, buffer) in array.frames.iter().enumerate() {
913            visitor.visit_buffer_handle(
914                &format!("frame_{i}"),
915                &BufferHandle::new_host(buffer.clone()),
916            );
917        }
918    }
919
920    fn nbuffers(array: &ZstdArray) -> usize {
921        array.dictionary.is_some() as usize + array.frames.len()
922    }
923
924    fn visit_children(array: &ZstdArray, visitor: &mut dyn ArrayChildVisitor) {
925        visitor.visit_validity(&array.unsliced_validity, array.unsliced_n_rows());
926    }
927
928    fn nchildren(array: &ZstdArray) -> usize {
929        validity_nchildren(&array.unsliced_validity)
930    }
931}