// vortex_zstd/array.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use itertools::Itertools as _;
9use prost::Message as _;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayRef;
13use vortex_array::Canonical;
14use vortex_array::DynArray;
15use vortex_array::ExecutionCtx;
16use vortex_array::ExecutionResult;
17use vortex_array::IntoArray;
18use vortex_array::LEGACY_SESSION;
19use vortex_array::Precision;
20use vortex_array::ProstMetadata;
21use vortex_array::ToCanonical;
22use vortex_array::VortexSessionExecute;
23use vortex_array::accessor::ArrayAccessor;
24use vortex_array::arrays::ConstantArray;
25use vortex_array::arrays::PrimitiveArray;
26use vortex_array::arrays::VarBinViewArray;
27use vortex_array::arrays::varbinview::build_views::BinaryView;
28use vortex_array::arrays::varbinview::build_views::MAX_BUFFER_LEN;
29use vortex_array::buffer::BufferHandle;
30use vortex_array::dtype::DType;
31use vortex_array::scalar::Scalar;
32use vortex_array::serde::ArrayChildren;
33use vortex_array::stats::ArrayStats;
34use vortex_array::stats::StatsSetRef;
35use vortex_array::validity::Validity;
36use vortex_array::vtable;
37use vortex_array::vtable::Array;
38use vortex_array::vtable::ArrayId;
39use vortex_array::vtable::OperationsVTable;
40use vortex_array::vtable::VTable;
41use vortex_array::vtable::ValidityHelper;
42use vortex_array::vtable::ValiditySliceHelper;
43use vortex_array::vtable::ValidityVTableFromValiditySliceHelper;
44use vortex_array::vtable::validity_nchildren;
45use vortex_array::vtable::validity_to_child;
46use vortex_buffer::Alignment;
47use vortex_buffer::Buffer;
48use vortex_buffer::BufferMut;
49use vortex_buffer::ByteBuffer;
50use vortex_buffer::ByteBufferMut;
51use vortex_error::VortexError;
52use vortex_error::VortexExpect;
53use vortex_error::VortexResult;
54use vortex_error::vortex_bail;
55use vortex_error::vortex_ensure;
56use vortex_error::vortex_err;
57use vortex_error::vortex_panic;
58use vortex_mask::AllOr;
59use vortex_session::VortexSession;
60
61use crate::ZstdFrameMetadata;
62use crate::ZstdMetadata;
63
// Zstd doesn't support training dictionaries on very few samples.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
// Type of the little-endian length prefix written before each variable-length
// value (see `collect_valid_vbv` and `reconstruct_views`).
type ViewLen = u32;
67
68// Overall approach here:
69// Zstd can be used on the whole array (values_per_frame = 0), resulting in a single Zstd
70// frame, or it can be used with a dictionary (values_per_frame < # values), resulting in
71// multiple Zstd frames sharing a common dictionary. This latter case is helpful if you
72// want somewhat faster access to slices or individual rows, allowing us to only
73// decompress the necessary frames.
74
75// Visually, during decompression, we have an interval of frames we're
76// decompressing and a tighter interval of the slice we actually care about.
77// |=============values (all valid elements)==============|
78// |<-skipped_uncompressed->|----decompressed-------------|
79//                              |------slice-------|
80//                              ^                  ^
81// |<-slice_uncompressed_start->|                  |
82// |<------------slice_uncompressed_stop---------->|
83// We then insert these values to the correct position using a primitive array
84// constructor.
85
86vtable!(Zstd);
87
impl VTable for Zstd {
    type Array = ZstdArray;

    type Metadata = ProstMetadata<ZstdMetadata>;
    type OperationsVTable = Self;
    // Validity queries are derived from a `ValiditySliceHelper` implementation
    // (presumably provided for `ZstdArray` elsewhere in this crate — confirm).
    type ValidityVTable = ValidityVTableFromValiditySliceHelper;

    fn vtable(_array: &Self::Array) -> &Self {
        // `Zstd` is a stateless unit struct, so a single shared value suffices.
        &Zstd
    }

    fn id(&self) -> ArrayId {
        Self::ID
    }

    // The logical length is the sliced extent, not the number of stored values.
    fn len(array: &ZstdArray) -> usize {
        array.slice_stop - array.slice_start
    }

    fn dtype(array: &ZstdArray) -> &DType {
        &array.dtype
    }

    fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
        array.stats_set.to_ref(array.as_ref())
    }

    // Hashes every field that participates in `array_eq`, in the same order.
    fn array_hash<H: std::hash::Hasher>(array: &ZstdArray, state: &mut H, precision: Precision) {
        match &array.dictionary {
            Some(dict) => {
                // Tag dictionary presence so a (dictionary, frames...) array
                // never hash-collides with a dictionary-less array whose first
                // frame happens to equal the dictionary bytes.
                true.hash(state);
                dict.array_hash(state, precision);
            }
            None => {
                false.hash(state);
            }
        }
        for frame in &array.frames {
            frame.array_hash(state, precision);
        }
        array.dtype.hash(state);
        array.unsliced_validity.array_hash(state, precision);
        array.unsliced_n_rows.hash(state);
        array.slice_start.hash(state);
        array.slice_stop.hash(state);
    }

    // Structural equality over the same fields hashed in `array_hash`.
    fn array_eq(array: &ZstdArray, other: &ZstdArray, precision: Precision) -> bool {
        // Dictionaries must either both be absent or compare equal.
        if !match (&array.dictionary, &other.dictionary) {
            (Some(d1), Some(d2)) => d1.array_eq(d2, precision),
            (None, None) => true,
            _ => false,
        } {
            return false;
        }
        if array.frames.len() != other.frames.len() {
            return false;
        }
        for (a, b) in array.frames.iter().zip(&other.frames) {
            if !a.array_eq(b, precision) {
                return false;
            }
        }
        array.dtype == other.dtype
            && array
                .unsliced_validity
                .array_eq(&other.unsliced_validity, precision)
            && array.unsliced_n_rows == other.unsliced_n_rows
            && array.slice_start == other.slice_start
            && array.slice_stop == other.slice_stop
    }

    // Buffer layout: [dictionary (if any), frame_0, frame_1, ...].
    fn nbuffers(array: &ZstdArray) -> usize {
        array.dictionary.is_some() as usize + array.frames.len()
    }

    fn buffer(array: &ZstdArray, idx: usize) -> BufferHandle {
        if let Some(dict) = &array.dictionary {
            // Buffer 0 is the dictionary; frame indices are shifted up by one.
            if idx == 0 {
                return BufferHandle::new_host(dict.clone());
            }
            BufferHandle::new_host(array.frames[idx - 1].clone())
        } else {
            BufferHandle::new_host(array.frames[idx].clone())
        }
    }

    // Mirrors the layout in `buffer`: optional "dictionary", then "frame_{i}".
    fn buffer_name(array: &ZstdArray, idx: usize) -> Option<String> {
        if array.dictionary.is_some() {
            if idx == 0 {
                Some("dictionary".to_string())
            } else {
                Some(format!("frame_{}", idx - 1))
            }
        } else {
            Some(format!("frame_{idx}"))
        }
    }

    // The only (optional) child is the validity array.
    fn nchildren(array: &ZstdArray) -> usize {
        validity_nchildren(&array.unsliced_validity)
    }

    fn child(array: &ZstdArray, idx: usize) -> ArrayRef {
        // `idx` is only used for the panic message; validity is the sole child.
        validity_to_child(&array.unsliced_validity, array.unsliced_n_rows)
            .unwrap_or_else(|| vortex_panic!("ZstdArray child index {idx} out of bounds"))
    }

    fn child_name(_array: &ZstdArray, idx: usize) -> String {
        match idx {
            0 => "validity".to_string(),
            _ => vortex_panic!("ZstdArray child_name index {idx} out of bounds"),
        }
    }

    fn metadata(array: &ZstdArray) -> VortexResult<Self::Metadata> {
        Ok(ProstMetadata(array.metadata.clone()))
    }

    // Metadata round-trips through prost (protobuf) encoding.
    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
        Ok(Some(metadata.0.encode_to_vec()))
    }

    fn deserialize(
        bytes: &[u8],
        _dtype: &DType,
        _len: usize,
        _buffers: &[BufferHandle],
        _session: &VortexSession,
    ) -> VortexResult<Self::Metadata> {
        Ok(ProstMetadata(ZstdMetadata::decode(bytes)?))
    }

    // Reassembles a `ZstdArray` from its serialized buffers and children.
    fn build(
        dtype: &DType,
        len: usize,
        metadata: &Self::Metadata,
        buffers: &[BufferHandle],
        children: &dyn ArrayChildren,
    ) -> VortexResult<ZstdArray> {
        // Zero children: validity is implied by the dtype's nullability.
        // One child: an explicit validity array.
        let validity = if children.is_empty() {
            Validity::from(dtype.nullability())
        } else if children.len() == 1 {
            let validity = children.get(0, &Validity::DTYPE, len)?;
            Validity::Array(validity)
        } else {
            vortex_bail!("ZstdArray expected 0 or 1 child, got {}", children.len());
        };

        // `dictionary_size == 0` encodes "no dictionary"; otherwise buffer 0
        // is the dictionary and the remainder are frames (see `buffer`).
        let (dictionary_buffer, compressed_buffers) = if metadata.0.dictionary_size == 0 {
            // no dictionary
            (
                None,
                buffers
                    .iter()
                    .map(|b| b.clone().try_to_host_sync())
                    .collect::<VortexResult<Vec<_>>>()?,
            )
        } else {
            // with dictionary
            (
                Some(buffers[0].clone().try_to_host_sync()?),
                buffers[1..]
                    .iter()
                    .map(|b| b.clone().try_to_host_sync())
                    .collect::<VortexResult<Vec<_>>>()?,
            )
        };

        Ok(ZstdArray::new(
            dictionary_buffer,
            compressed_buffers,
            dtype.clone(),
            metadata.0.clone(),
            len,
            validity,
        ))
    }

    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
        vortex_ensure!(
            children.len() <= 1,
            "ZstdArray expects at most 1 child (validity), got {}",
            children.len()
        );

        array.unsliced_validity = if children.is_empty() {
            Validity::from(array.dtype.nullability())
        } else {
            Validity::Array(children.into_iter().next().vortex_expect("checked"))
        };

        Ok(())
    }

    // Executing a Zstd array decompresses it and delegates execution to the
    // decompressed (canonical-encoded) array.
    fn execute(array: Arc<Array<Self>>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
        array
            .decompress(ctx)?
            .execute::<ArrayRef>(ctx)
            .map(ExecutionResult::done)
    }

    // Delegates parent-reduction to this crate's rewrite rules.
    fn reduce_parent(
        array: &Array<Self>,
        parent: &ArrayRef,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        crate::rules::RULES.evaluate(array, parent, child_idx)
    }
}
298
/// Marker type implementing the Zstd encoding's [`VTable`].
#[derive(Clone, Debug)]
pub struct Zstd;

impl Zstd {
    /// Stable encoding identifier for registering and looking up this vtable.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.zstd");
}
305
/// A Zstd-compressed array: one or more compressed frames, optionally sharing
/// a trained dictionary, plus the validity of the original (uncompressed) data.
#[derive(Clone, Debug)]
pub struct ZstdArray {
    /// Optional trained zstd dictionary shared by all frames.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// The compressed frames, in row order.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Per-frame uncompressed sizes / value counts and the dictionary size.
    pub(crate) metadata: ZstdMetadata,
    dtype: DType,
    /// Validity of the FULL (unsliced) array; only valid values are compressed.
    pub(crate) unsliced_validity: Validity,
    // Row count of the full, unsliced array.
    unsliced_n_rows: usize,
    stats_set: ArrayStats,
    // Slicing is lazy: these bounds (in unsliced rows) select the visible rows,
    // so `len() == slice_stop - slice_start`.
    slice_start: usize,
    slice_stop: usize,
}
318
/// The parts of a [`ZstdArray`] returned by [`ZstdArray::into_parts`].
///
/// The slice offsets mirror [`ZstdArray`]'s internal lazy-slice bounds: the
/// logical length of the source array was `slice_stop - slice_start`.
#[derive(Debug)]
pub struct ZstdArrayParts {
    /// The optional dictionary used for compression.
    pub dictionary: Option<ByteBuffer>,
    /// The compressed frames.
    pub frames: Vec<ByteBuffer>,
    /// The compression metadata.
    pub metadata: ZstdMetadata,
    /// The data type of the uncompressed array.
    pub dtype: DType,
    /// The validity of the uncompressed array.
    pub validity: Validity,
    /// The number of rows in the uncompressed array.
    pub n_rows: usize,
    /// Slice start offset (in rows of the unsliced array).
    pub slice_start: usize,
    /// Slice stop offset (in rows of the unsliced array).
    pub slice_stop: usize,
}
339
/// Intermediate result of `ZstdArray::compress_values`: the compressed frames,
/// their per-frame metadata, and the optional trained dictionary.
struct Frames {
    dictionary: Option<ByteBuffer>,
    frames: Vec<ByteBuffer>,
    frame_metas: Vec<ZstdFrameMetadata>,
}
345
/// Picks the maximum dictionary size to train for `uncompressed_size` bytes of
/// sample data.
///
/// Follows the upstream zstd recommendation
/// (https://github.com/facebook/zstd/blob/v1.5.5/lib/zdict.h#L190): roughly
/// 1/100 of the training data, capped at 100 kB. Zstd also appears unable to
/// train dictionaries smaller than 256 bytes, hence the floor.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const MIN_DICT_SIZE: usize = 256;
    const MAX_DICT_SIZE: usize = 100 * 1024;
    let target = uncompressed_size / 100;
    target.clamp(MIN_DICT_SIZE, MAX_DICT_SIZE)
}
353
354fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
355    let mask = parray.validity_mask()?;
356    Ok(parray.clone().into_array().filter(mask)?.to_primitive())
357}
358
/// Gathers all valid (non-null) values of `vbv` into one contiguous buffer of
/// length-prefixed bytes (a little-endian `ViewLen` length followed by the
/// value bytes), and returns the byte offset at which each value's length
/// prefix begins.
fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
    let mask = vbv.validity_mask()?;
    let buffer_and_value_byte_indices = match mask.bit_buffer() {
        // All-null array: nothing to store.
        AllOr::None => (Buffer::empty(), Vec::new()),
        _ => {
            // Reserve room for every value byte plus one length prefix per
            // valid value (an over-estimate when some values are null).
            let mut buffer = BufferMut::with_capacity(
                usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
                    + mask.true_count() * size_of::<ViewLen>(),
            );
            let mut value_byte_indices = Vec::new();
            vbv.with_iterator(|iterator| {
                // by flattening, we should omit nulls
                for value in iterator.flatten() {
                    value_byte_indices.push(buffer.len());
                    // here's where we write the string lengths
                    buffer
                        .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
                    buffer.extend_from_slice(value);
                }
                Ok::<_, VortexError>(())
            })?;
            (buffer.freeze(), value_byte_indices)
        }
    };
    Ok(buffer_and_value_byte_indices)
}
385
/// Reconstruct BinaryView structs from length-prefixed byte data.
///
/// The buffer contains interleaved u32 lengths (little-endian) and string data.
/// When the cumulative data exceeds `max_buffer_len`, the buffer is split (zero-copy) into
/// multiple segments so that BinaryView's u32 offsets can address all data.
///
/// Pass [`MAX_BUFFER_LEN`] for `max_buffer_len` in production; a smaller value can be used in
/// tests to exercise the splitting path without allocating >2 GiB.
pub fn reconstruct_views(
    buffer: &ByteBuffer,
    max_buffer_len: usize,
) -> (Vec<ByteBuffer>, Buffer<BinaryView>) {
    let mut views = BufferMut::<BinaryView>::empty();
    let mut buffers = Vec::new();
    // Start (absolute offset into `buffer`) of the segment being accumulated.
    let mut segment_start: usize = 0;
    // Absolute offset of the next length prefix to read.
    let mut offset = 0;

    while offset < buffer.len() {
        // Read the little-endian length prefix for the next value.
        let str_len = ViewLen::from_le_bytes(
            buffer
                .get(offset..offset + size_of::<ViewLen>())
                .vortex_expect("corrupted zstd length")
                .try_into()
                .ok()
                .vortex_expect("must fit ViewLen size"),
        ) as usize;

        let value_data_offset = offset + size_of::<ViewLen>();
        let local_offset = value_data_offset - segment_start;

        // If this value would push the current segment past `max_buffer_len`,
        // close the segment at `offset` (zero-copy slice) and start a new one —
        // but only if the current segment is non-empty, so a single oversized
        // value does not produce an empty segment.
        if local_offset + str_len > max_buffer_len && offset > segment_start {
            buffers.push(buffer.slice(segment_start..offset));
            segment_start = offset;
        }

        // Recompute the offset relative to the (possibly new) segment; views
        // index into `buffers[buf_index]` at `local_offset`.
        let local_offset = u32::try_from(value_data_offset - segment_start)
            .vortex_expect("local offset within segment must fit in u32");
        let buf_index = u32::try_from(buffers.len()).vortex_expect("buffer index must fit in u32");
        let value = &buffer[value_data_offset..value_data_offset + str_len];
        views.push(BinaryView::make_view(value, buf_index, local_offset));
        offset = value_data_offset + str_len;
    }

    // Flush the trailing segment, if any.
    if segment_start < buffer.len() {
        buffers.push(buffer.slice(segment_start..buffer.len()));
    }

    (buffers, views.freeze())
}
435
436impl ZstdArray {
    /// Constructs a `ZstdArray` covering all `n_rows` rows (unsliced).
    ///
    /// `frames` are the zstd-compressed frames in row order; `dictionary` is
    /// the optional shared dictionary they were compressed with; `validity`
    /// describes which of the `n_rows` rows are valid.
    pub fn new(
        dictionary: Option<ByteBuffer>,
        frames: Vec<ByteBuffer>,
        dtype: DType,
        metadata: ZstdMetadata,
        n_rows: usize,
        validity: Validity,
    ) -> Self {
        Self {
            dictionary,
            frames,
            metadata,
            dtype,
            unsliced_validity: validity,
            unsliced_n_rows: n_rows,
            stats_set: Default::default(),
            // A fresh array is unsliced: the slice covers every row.
            slice_start: 0,
            slice_stop: n_rows,
        }
    }
457
    /// Compresses `value_bytes` into one zstd frame per entry of
    /// `frame_byte_starts`, optionally training a shared dictionary first.
    ///
    /// `frame_byte_starts[i]` is the byte offset where frame `i` begins; frame
    /// `i` ends at `frame_byte_starts[i + 1]` (or at the end of `value_bytes`
    /// for the last frame).
    fn compress_values(
        value_bytes: &ByteBuffer,
        frame_byte_starts: &[usize],
        level: i32,
        values_per_frame: usize,
        n_values: usize,
        use_dictionary: bool,
    ) -> VortexResult<Frames> {
        let n_frames = frame_byte_starts.len();

        // Would-be sample sizes if we end up applying zstd dictionary
        let mut sample_sizes = Vec::with_capacity(n_frames);
        for i in 0..n_frames {
            let frame_byte_end = frame_byte_starts
                .get(i + 1)
                .copied()
                .unwrap_or(value_bytes.len());
            sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
        }
        // The frames must exactly tile the input bytes.
        debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());

        // Zstd can't train a dictionary on too few samples (see
        // MIN_SAMPLES_FOR_DICTIONARY), so fall back to plain compression then,
        // or when the caller opted out of dictionaries.
        let (dictionary, mut compressor) = if !use_dictionary
            || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY
        {
            // no dictionary
            (None, zstd::bulk::Compressor::new(level)?)
        } else {
            // with dictionary
            let max_dict_size = choose_max_dict_size(value_bytes.len());
            let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
                .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;

            let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
            (Some(ByteBuffer::from(dict)), compressor)
        };

        let mut frame_metas = vec![];
        let mut frames = vec![];
        for i in 0..n_frames {
            let frame_byte_end = frame_byte_starts
                .get(i + 1)
                .copied()
                .unwrap_or(value_bytes.len());

            let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
            let compressed = compressor
                .compress(uncompressed)
                .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
            frame_metas.push(ZstdFrameMetadata {
                uncompressed_size: uncompressed.len() as u64,
                // The final frame may hold fewer than `values_per_frame` values.
                // NOTE(review): this assumes each frame holds exactly
                // `values_per_frame` values; callers must size frames that way.
                n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
            });
            frames.push(ByteBuffer::from(compressed));
        }

        Ok(Frames {
            dictionary,
            frames,
            frame_metas,
        })
    }
519
    /// Creates a ZstdArray from a primitive array.
    ///
    /// Convenience wrapper around the shared implementation with dictionary
    /// training enabled.
    ///
    /// # Arguments
    /// * `parray` - The primitive array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_primitive(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, true)
    }
533
    /// Creates a ZstdArray from a primitive array without using a dictionary.
    ///
    /// This is useful when the compressed data will be decompressed by systems
    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
    ///
    /// Note: Without a dictionary, each frame is compressed independently.
    /// Dictionaries are trained from sample data from previously seen frames,
    /// to improve compression ratio.
    ///
    /// # Arguments
    /// * `parray` - The primitive array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_primitive_without_dict(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, false)
    }
554
555    fn from_primitive_impl(
556        parray: &PrimitiveArray,
557        level: i32,
558        values_per_frame: usize,
559        use_dictionary: bool,
560    ) -> VortexResult<Self> {
561        let dtype = parray.dtype().clone();
562        let byte_width = parray.ptype().byte_width();
563
564        // We compress only the valid elements.
565        let values = collect_valid_primitive(parray)?;
566        let n_values = values.len();
567        let values_per_frame = if values_per_frame > 0 {
568            values_per_frame
569        } else {
570            n_values
571        };
572
573        let value_bytes = values.buffer_handle().try_to_host_sync()?;
574        // Align frames to buffer alignment. This is necessary for overaligned buffers.
575        let alignment = *value_bytes.alignment();
576        let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment;
577
578        let frame_byte_starts = (0..n_values * byte_width)
579            .step_by(step_width)
580            .collect::<Vec<_>>();
581        let Frames {
582            dictionary,
583            frames,
584            frame_metas,
585        } = Self::compress_values(
586            &value_bytes,
587            &frame_byte_starts,
588            level,
589            values_per_frame,
590            n_values,
591            use_dictionary,
592        )?;
593
594        let metadata = ZstdMetadata {
595            dictionary_size: dictionary
596                .as_ref()
597                .map_or(0, |dict| dict.len())
598                .try_into()?,
599            frames: frame_metas,
600        };
601
602        Ok(ZstdArray::new(
603            dictionary,
604            frames,
605            dtype,
606            metadata,
607            parray.len(),
608            parray.validity().clone(),
609        ))
610    }
611
    /// Creates a ZstdArray from a VarBinView array.
    ///
    /// Convenience wrapper around the shared implementation with dictionary
    /// training enabled.
    ///
    /// # Arguments
    /// * `vbv` - The VarBinView array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_var_bin_view(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, true)
    }
625
    /// Creates a ZstdArray from a VarBinView array without using a dictionary.
    ///
    /// This is useful when the compressed data will be decompressed by systems
    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
    ///
    /// Note: Without a dictionary, each frame is compressed independently.
    /// Dictionaries are trained from sample data from previously seen frames,
    /// to improve compression ratio.
    ///
    /// # Arguments
    /// * `vbv` - The VarBinView array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_var_bin_view_without_dict(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, false)
    }
646
647    fn from_var_bin_view_impl(
648        vbv: &VarBinViewArray,
649        level: i32,
650        values_per_frame: usize,
651        use_dictionary: bool,
652    ) -> VortexResult<Self> {
653        // Approach for strings: we prefix each string with its length as a u32.
654        // This is the same as what Parquet does. In some cases it may be better
655        // to separate the binary data and lengths as two separate streams, but
656        // this approach is simpler and can be best in cases when there is
657        // mutual information between strings and their lengths.
658        let dtype = vbv.dtype().clone();
659
660        // We compress only the valid elements.
661        let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
662        let n_values = value_byte_indices.len();
663        let values_per_frame = if values_per_frame > 0 {
664            values_per_frame
665        } else {
666            n_values
667        };
668
669        let frame_byte_starts = (0..n_values)
670            .step_by(values_per_frame)
671            .map(|i| value_byte_indices[i])
672            .collect::<Vec<_>>();
673        let Frames {
674            dictionary,
675            frames,
676            frame_metas,
677        } = Self::compress_values(
678            &value_bytes,
679            &frame_byte_starts,
680            level,
681            values_per_frame,
682            n_values,
683            use_dictionary,
684        )?;
685
686        let metadata = ZstdMetadata {
687            dictionary_size: dictionary
688                .as_ref()
689                .map_or(0, |dict| dict.len())
690                .try_into()?,
691            frames: frame_metas,
692        };
693        Ok(ZstdArray::new(
694            dictionary,
695            frames,
696            dtype,
697            metadata,
698            vbv.len(),
699            vbv.validity().clone(),
700        ))
701    }
702
703    pub fn from_canonical(
704        canonical: &Canonical,
705        level: i32,
706        values_per_frame: usize,
707    ) -> VortexResult<Option<Self>> {
708        match canonical {
709            Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
710                parray,
711                level,
712                values_per_frame,
713            )?)),
714            Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
715                vbv,
716                level,
717                values_per_frame,
718            )?)),
719            _ => Ok(None),
720        }
721    }
722
723    pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
724        Self::from_canonical(&array.to_canonical()?, level, values_per_frame)?
725            .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
726    }
727
728    fn byte_width(&self) -> usize {
729        if self.dtype.is_primitive() {
730            self.dtype.as_ptype().byte_width()
731        } else {
732            1
733        }
734    }
735
736    pub fn decompress(&self, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
737        // To start, we figure out which frames we need to decompress, and with
738        // what row offset into the first such frame.
739        let byte_width = self.byte_width();
740        let slice_n_rows = self.slice_stop - self.slice_start;
741        let slice_value_indices = self
742            .unsliced_validity
743            .execute_mask(self.unsliced_n_rows, ctx)?
744            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
745
746        let slice_value_idx_start = slice_value_indices[0];
747        let slice_value_idx_stop = slice_value_indices[1];
748
749        let mut frames_to_decompress = vec![];
750        let mut value_idx_start = 0;
751        let mut uncompressed_size_to_decompress = 0;
752        let mut n_skipped_values = 0;
753        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
754            if value_idx_start >= slice_value_idx_stop {
755                break;
756            }
757
758            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
759                .vortex_expect("Uncompressed size must fit in usize");
760            let frame_n_values = if frame_meta.n_values == 0 {
761                // possibly older primitive-only metadata that just didn't store this
762                frame_uncompressed_size / byte_width
763            } else {
764                usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
765            };
766
767            let value_idx_stop = value_idx_start + frame_n_values;
768            if value_idx_stop > slice_value_idx_start {
769                // we need this frame
770                frames_to_decompress.push(frame);
771                uncompressed_size_to_decompress += frame_uncompressed_size;
772            } else {
773                n_skipped_values += frame_n_values;
774            }
775            value_idx_start = value_idx_stop;
776        }
777
778        // then we actually decompress those frames
779        let mut decompressor = if let Some(dictionary) = &self.dictionary {
780            zstd::bulk::Decompressor::with_dictionary(dictionary)?
781        } else {
782            zstd::bulk::Decompressor::new()?
783        };
784        let mut decompressed = ByteBufferMut::with_capacity_aligned(
785            uncompressed_size_to_decompress,
786            Alignment::new(byte_width),
787        );
788        unsafe {
789            // safety: we immediately fill all bytes in the following loop,
790            // assuming our metadata's uncompressed size is correct
791            decompressed.set_len(uncompressed_size_to_decompress);
792        }
793        let mut uncompressed_start = 0;
794        for frame in frames_to_decompress {
795            let uncompressed_written = decompressor
796                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])?;
797            uncompressed_start += uncompressed_written;
798        }
799        if uncompressed_start != uncompressed_size_to_decompress {
800            vortex_panic!(
801                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
802                uncompressed_size_to_decompress,
803                uncompressed_start
804            );
805        }
806
807        let decompressed = decompressed.freeze();
808        // Last, we slice the exact values requested out of the decompressed data.
809        let mut slice_validity = self
810            .unsliced_validity
811            .slice(self.slice_start..self.slice_stop)?;
812
813        // NOTE: this block handles setting the output type when the validity and DType disagree.
814        //
815        // ZSTD is a compact block compressor, meaning that null values are not stored inline in
816        // the data frames. A ZSTD Array that was initialized must always hold onto its full
817        // validity bitmap, even if sliced to only include non-null values.
818        //
819        // We ensure that the validity of the decompressed array ALWAYS matches the validity
820        // implied by the DType.
821        if !self.dtype().is_nullable() && !matches!(slice_validity, Validity::NonNullable) {
822            assert!(
823                matches!(slice_validity, Validity::AllValid),
824                "ZSTD array expects to be non-nullable but there are nulls after decompression"
825            );
826
827            slice_validity = Validity::NonNullable;
828        } else if self.dtype.is_nullable() && matches!(slice_validity, Validity::NonNullable) {
829            slice_validity = Validity::AllValid;
830        }
831        //
832        // END OF IMPORTANT BLOCK
833        //
834
835        match &self.dtype {
836            DType::Primitive(..) => {
837                let slice_values_buffer = decompressed.slice(
838                    (slice_value_idx_start - n_skipped_values) * byte_width
839                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
840                );
841                let primitive = PrimitiveArray::from_values_byte_buffer(
842                    slice_values_buffer,
843                    self.dtype.as_ptype(),
844                    slice_validity,
845                    slice_n_rows,
846                );
847
848                Ok(primitive.into_array())
849            }
850            DType::Binary(_) | DType::Utf8(_) => {
851                match slice_validity.execute_mask(slice_n_rows, ctx)?.indices() {
852                    AllOr::All => {
853                        let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
854                        let valid_views = all_views.slice(
855                            slice_value_idx_start - n_skipped_values
856                                ..slice_value_idx_stop - n_skipped_values,
857                        );
858
859                        // SAFETY: we properly construct the views inside `reconstruct_views`
860                        Ok(unsafe {
861                            VarBinViewArray::new_unchecked(
862                                valid_views,
863                                Arc::from(buffers),
864                                self.dtype.clone(),
865                                slice_validity,
866                            )
867                        }
868                        .into_array())
869                    }
870                    AllOr::None => Ok(ConstantArray::new(
871                        Scalar::null(self.dtype.clone()),
872                        slice_n_rows,
873                    )
874                    .into_array()),
875                    AllOr::Some(valid_indices) => {
876                        let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
877                        let valid_views = all_views.slice(
878                            slice_value_idx_start - n_skipped_values
879                                ..slice_value_idx_stop - n_skipped_values,
880                        );
881
882                        let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
883                        for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
884                            views[*index] = view
885                        }
886
887                        // SAFETY: we properly construct the views inside `reconstruct_views`
888                        Ok(unsafe {
889                            VarBinViewArray::new_unchecked(
890                                views.freeze(),
891                                Arc::from(buffers),
892                                self.dtype.clone(),
893                                slice_validity,
894                            )
895                        }
896                        .into_array())
897                    }
898                }
899            }
900            _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype),
901        }
902    }
903
904    pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
905        let new_start = self.slice_start + start;
906        let new_stop = self.slice_start + stop;
907
908        assert!(
909            new_start <= self.slice_stop,
910            "new slice start {new_start} exceeds end {}",
911            self.slice_stop
912        );
913
914        assert!(
915            new_stop <= self.slice_stop,
916            "new slice stop {new_stop} exceeds end {}",
917            self.slice_stop
918        );
919
920        ZstdArray {
921            slice_start: self.slice_start + start,
922            slice_stop: self.slice_start + stop,
923            stats_set: Default::default(),
924            ..self.clone()
925        }
926    }
927
928    /// Consumes the array and returns its parts.
929    pub fn into_parts(self) -> ZstdArrayParts {
930        ZstdArrayParts {
931            dictionary: self.dictionary,
932            frames: self.frames,
933            metadata: self.metadata,
934            dtype: self.dtype,
935            validity: self.unsliced_validity,
936            n_rows: self.unsliced_n_rows,
937            slice_start: self.slice_start,
938            slice_stop: self.slice_stop,
939        }
940    }
941
    /// Logical element type of this array's values.
    pub(crate) fn dtype(&self) -> &DType {
        &self.dtype
    }
945
    /// Inclusive start of the current slice, as a row offset into the
    /// unsliced (fully decompressed) data.
    pub(crate) fn slice_start(&self) -> usize {
        self.slice_start
    }
949
    /// Exclusive end of the current slice, as a row offset into the
    /// unsliced (fully decompressed) data.
    pub(crate) fn slice_stop(&self) -> usize {
        self.slice_stop
    }
953
    /// Total row count of the underlying data before any slicing.
    pub(crate) fn unsliced_n_rows(&self) -> usize {
        self.unsliced_n_rows
    }
957}
958
impl ValiditySliceHelper for ZstdArray {
    /// Exposes the full unsliced validity together with the current slice
    /// bounds, so the validity-slice helper machinery can derive the
    /// validity for the visible window.
    fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
        (&self.unsliced_validity, self.slice_start, self.slice_stop)
    }
}
964
965impl OperationsVTable<Zstd> for Zstd {
966    fn scalar_at(array: &ZstdArray, index: usize, _ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
967        let mut ctx = LEGACY_SESSION.create_execution_ctx();
968        array
969            ._slice(index, index + 1)
970            .decompress(&mut ctx)?
971            .scalar_at(0)
972    }
973}
974
#[cfg(test)]
#[allow(clippy::cast_possible_truncation)]
mod tests {
    use vortex_buffer::ByteBuffer;

    use super::reconstruct_views;
    use crate::array::BinaryView;

    /// Build a Zstd-style interleaved buffer: [u32-LE length][string bytes] repeated.
    fn make_interleaved(strings: &[&[u8]]) -> ByteBuffer {
        // The output size is known exactly (4-byte length prefix per entry
        // plus the payload), so preallocate to avoid repeated reallocation.
        let total: usize = strings.iter().map(|s| 4 + s.len()).sum();
        let mut buf = Vec::with_capacity(total);
        for s in strings {
            let len = s.len() as u32;
            buf.extend_from_slice(&len.to_le_bytes());
            buf.extend_from_slice(s);
        }
        debug_assert_eq!(buf.len(), total);
        ByteBuffer::copy_from(buf.as_slice())
    }

    #[test]
    fn test_reconstruct_views_no_split() {
        let strings: &[&[u8]] = &[b"hello", b"world"];
        let buf = make_interleaved(strings);
        let (buffers, views) = reconstruct_views(&buf, 1024);

        assert_eq!(buffers.len(), 1);
        assert_eq!(views.len(), 2);
        // Each entry: [u32 len (4 bytes)][data], so offsets are 4 and 4+5+4=13
        assert_eq!(views[0], BinaryView::make_view(b"hello", 0, 4));
        assert_eq!(views[1], BinaryView::make_view(b"world", 0, 13));
    }

    #[test]
    fn test_reconstruct_views_split_across_segments() {
        // "aaaaaaaaaaaaa" (13 bytes) and "bbbbbbbbbbbbb" (13 bytes).
        // Each entry occupies 4 (length prefix) + 13 (data) = 17 bytes.
        // With max_buffer_len=20, the second entry's data (offset 4+13+4=21) exceeds the limit,
        // so it rolls into a second segment.
        let strings: &[&[u8]] = &[b"aaaaaaaaaaaaa", b"bbbbbbbbbbbbb"];
        let buf = make_interleaved(strings);
        let (buffers, views) = reconstruct_views(&buf, 20);

        assert_eq!(buffers.len(), 2);
        assert_eq!(views.len(), 2);
        assert_eq!(views[0], BinaryView::make_view(b"aaaaaaaaaaaaa", 0, 4));
        // Second entry starts a new segment at byte 17 (the length prefix), so local offset = 4.
        assert_eq!(views[1], BinaryView::make_view(b"bbbbbbbbbbbbb", 1, 4));
    }
}
1023}