Skip to main content

vortex_zstd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use itertools::Itertools as _;
9use prost::Message as _;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayRef;
13use vortex_array::Canonical;
14use vortex_array::DynArray;
15use vortex_array::ExecutionCtx;
16use vortex_array::ExecutionResult;
17use vortex_array::IntoArray;
18use vortex_array::LEGACY_SESSION;
19use vortex_array::Precision;
20use vortex_array::ProstMetadata;
21use vortex_array::ToCanonical;
22use vortex_array::VortexSessionExecute;
23use vortex_array::accessor::ArrayAccessor;
24use vortex_array::arrays::ConstantArray;
25use vortex_array::arrays::PrimitiveArray;
26use vortex_array::arrays::VarBinViewArray;
27use vortex_array::arrays::varbinview::build_views::BinaryView;
28use vortex_array::arrays::varbinview::build_views::MAX_BUFFER_LEN;
29use vortex_array::buffer::BufferHandle;
30use vortex_array::dtype::DType;
31use vortex_array::scalar::Scalar;
32use vortex_array::serde::ArrayChildren;
33use vortex_array::stats::ArrayStats;
34use vortex_array::stats::StatsSetRef;
35use vortex_array::validity::Validity;
36use vortex_array::vtable;
37use vortex_array::vtable::ArrayId;
38use vortex_array::vtable::OperationsVTable;
39use vortex_array::vtable::VTable;
40use vortex_array::vtable::ValidityHelper;
41use vortex_array::vtable::ValiditySliceHelper;
42use vortex_array::vtable::ValidityVTableFromValiditySliceHelper;
43use vortex_array::vtable::validity_nchildren;
44use vortex_array::vtable::validity_to_child;
45use vortex_buffer::Alignment;
46use vortex_buffer::Buffer;
47use vortex_buffer::BufferMut;
48use vortex_buffer::ByteBuffer;
49use vortex_buffer::ByteBufferMut;
50use vortex_error::VortexError;
51use vortex_error::VortexExpect;
52use vortex_error::VortexResult;
53use vortex_error::vortex_bail;
54use vortex_error::vortex_ensure;
55use vortex_error::vortex_err;
56use vortex_error::vortex_panic;
57use vortex_mask::AllOr;
58use vortex_session::VortexSession;
59
60use crate::ZstdFrameMetadata;
61use crate::ZstdMetadata;
62
// Zstd doesn't support training dictionaries on very few samples.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
// Width of the little-endian length prefix written before each value when
// packing variable-length (VarBinView) data for compression.
type ViewLen = u32;
66
67// Overall approach here:
68// Zstd can be used on the whole array (values_per_frame = 0), resulting in a single Zstd
69// frame, or it can be used with a dictionary (values_per_frame < # values), resulting in
70// multiple Zstd frames sharing a common dictionary. This latter case is helpful if you
71// want somewhat faster access to slices or individual rows, allowing us to only
72// decompress the necessary frames.
73
74// Visually, during decompression, we have an interval of frames we're
75// decompressing and a tighter interval of the slice we actually care about.
76// |=============values (all valid elements)==============|
77// |<-skipped_uncompressed->|----decompressed-------------|
78//                              |------slice-------|
79//                              ^                  ^
80// |<-slice_uncompressed_start->|                  |
81// |<------------slice_uncompressed_stop---------->|
82// We then insert these values to the correct position using a primitive array
83// constructor.
84
85vtable!(Zstd);
86
impl VTable for Zstd {
    type Array = ZstdArray;

    type Metadata = ProstMetadata<ZstdMetadata>;
    type OperationsVTable = Self;
    type ValidityVTable = ValidityVTableFromValiditySliceHelper;

    fn vtable(_array: &Self::Array) -> &Self {
        &Zstd
    }

    fn id(&self) -> ArrayId {
        Self::ID
    }

    // The logical length is the width of the lazily-applied slice, not the
    // number of rows that were originally compressed.
    fn len(array: &ZstdArray) -> usize {
        array.slice_stop - array.slice_start
    }

    fn dtype(array: &ZstdArray) -> &DType {
        &array.dtype
    }

    fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
        array.stats_set.to_ref(array.as_ref())
    }

    // Hashes the compressed representation directly (no decompression).
    fn array_hash<H: std::hash::Hasher>(array: &ZstdArray, state: &mut H, precision: Precision) {
        // Hash a presence flag first so "no dictionary" and "some dictionary"
        // feed distinct prefixes into the hasher.
        match &array.dictionary {
            Some(dict) => {
                true.hash(state);
                dict.array_hash(state, precision);
            }
            None => {
                false.hash(state);
            }
        }
        for frame in &array.frames {
            frame.array_hash(state, precision);
        }
        array.dtype.hash(state);
        array.unsliced_validity.array_hash(state, precision);
        array.unsliced_n_rows.hash(state);
        array.slice_start.hash(state);
        array.slice_stop.hash(state);
    }

    // Equality mirrors `array_hash`: it compares the compressed representation
    // (dictionary, frames, dtype, validity, row counts, slice bounds).
    fn array_eq(array: &ZstdArray, other: &ZstdArray, precision: Precision) -> bool {
        if !match (&array.dictionary, &other.dictionary) {
            (Some(d1), Some(d2)) => d1.array_eq(d2, precision),
            (None, None) => true,
            _ => false,
        } {
            return false;
        }
        if array.frames.len() != other.frames.len() {
            return false;
        }
        for (a, b) in array.frames.iter().zip(&other.frames) {
            if !a.array_eq(b, precision) {
                return false;
            }
        }
        array.dtype == other.dtype
            && array
                .unsliced_validity
                .array_eq(&other.unsliced_validity, precision)
            && array.unsliced_n_rows == other.unsliced_n_rows
            && array.slice_start == other.slice_start
            && array.slice_stop == other.slice_stop
    }

    fn nbuffers(array: &ZstdArray) -> usize {
        array.dictionary.is_some() as usize + array.frames.len()
    }

    // Buffer layout: when a dictionary is present it occupies index 0,
    // followed by the frames in order; otherwise the frames start at index 0.
    fn buffer(array: &ZstdArray, idx: usize) -> BufferHandle {
        if let Some(dict) = &array.dictionary {
            if idx == 0 {
                return BufferHandle::new_host(dict.clone());
            }
            BufferHandle::new_host(array.frames[idx - 1].clone())
        } else {
            BufferHandle::new_host(array.frames[idx].clone())
        }
    }

    // Names follow the same layout as `buffer` above.
    fn buffer_name(array: &ZstdArray, idx: usize) -> Option<String> {
        if array.dictionary.is_some() {
            if idx == 0 {
                Some("dictionary".to_string())
            } else {
                Some(format!("frame_{}", idx - 1))
            }
        } else {
            Some(format!("frame_{idx}"))
        }
    }

    // The only possible child is the validity array.
    fn nchildren(array: &ZstdArray) -> usize {
        validity_nchildren(&array.unsliced_validity)
    }

    fn child(array: &ZstdArray, idx: usize) -> ArrayRef {
        validity_to_child(&array.unsliced_validity, array.unsliced_n_rows)
            .unwrap_or_else(|| vortex_panic!("ZstdArray child index {idx} out of bounds"))
    }

    fn child_name(_array: &ZstdArray, idx: usize) -> String {
        match idx {
            0 => "validity".to_string(),
            _ => vortex_panic!("ZstdArray child_name index {idx} out of bounds"),
        }
    }

    fn metadata(array: &ZstdArray) -> VortexResult<Self::Metadata> {
        Ok(ProstMetadata(array.metadata.clone()))
    }

    // Metadata is serialized as protobuf via prost.
    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
        Ok(Some(metadata.0.encode_to_vec()))
    }

    fn deserialize(
        bytes: &[u8],
        _dtype: &DType,
        _len: usize,
        _buffers: &[BufferHandle],
        _session: &VortexSession,
    ) -> VortexResult<Self::Metadata> {
        Ok(ProstMetadata(ZstdMetadata::decode(bytes)?))
    }

    // Rebuilds a ZstdArray from its serialized parts: an optional validity
    // child, and buffers laid out as described in `buffer` above.
    fn build(
        dtype: &DType,
        len: usize,
        metadata: &Self::Metadata,
        buffers: &[BufferHandle],
        children: &dyn ArrayChildren,
    ) -> VortexResult<ZstdArray> {
        let validity = if children.is_empty() {
            Validity::from(dtype.nullability())
        } else if children.len() == 1 {
            let validity = children.get(0, &Validity::DTYPE, len)?;
            Validity::Array(validity)
        } else {
            vortex_bail!("ZstdArray expected 0 or 1 child, got {}", children.len());
        };

        // `dictionary_size == 0` signals that no dictionary buffer was
        // serialized; otherwise buffer 0 holds the dictionary.
        let (dictionary_buffer, compressed_buffers) = if metadata.0.dictionary_size == 0 {
            // no dictionary
            (
                None,
                buffers
                    .iter()
                    .map(|b| b.clone().try_to_host_sync())
                    .collect::<VortexResult<Vec<_>>>()?,
            )
        } else {
            // with dictionary
            (
                Some(buffers[0].clone().try_to_host_sync()?),
                buffers[1..]
                    .iter()
                    .map(|b| b.clone().try_to_host_sync())
                    .collect::<VortexResult<Vec<_>>>()?,
            )
        };

        Ok(ZstdArray::new(
            dictionary_buffer,
            compressed_buffers,
            dtype.clone(),
            metadata.0.clone(),
            len,
            validity,
        ))
    }

    // Replaces the (at most one) validity child, resetting to the
    // nullability-implied validity when no child is given.
    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
        vortex_ensure!(
            children.len() <= 1,
            "ZstdArray expects at most 1 child (validity), got {}",
            children.len()
        );

        array.unsliced_validity = if children.is_empty() {
            Validity::from(array.dtype.nullability())
        } else {
            Validity::Array(children.into_iter().next().vortex_expect("checked"))
        };

        Ok(())
    }

    // Execution decompresses the array and delegates to the decompressed form.
    fn execute(array: Arc<Self::Array>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
        array
            .decompress(ctx)?
            .execute::<ArrayRef>(ctx)
            .map(ExecutionResult::done)
    }

    fn reduce_parent(
        array: &Self::Array,
        parent: &ArrayRef,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        crate::rules::RULES.evaluate(array, parent, child_idx)
    }
}
297
/// Marker type for the Zstd encoding; implements [`VTable`].
#[derive(Clone, Debug)]
pub struct Zstd;
300
impl Zstd {
    /// The stable identifier under which this encoding is registered.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.zstd");
}
304
/// An array whose values are stored as one or more Zstd-compressed frames,
/// optionally sharing a trained dictionary.
#[derive(Clone, Debug)]
pub struct ZstdArray {
    /// Optional Zstd dictionary shared by all frames.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// The compressed frames; only valid (non-null) values are stored.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Dictionary size plus per-frame sizes and value counts.
    pub(crate) metadata: ZstdMetadata,
    dtype: DType,
    /// Validity over the full, unsliced row range; slicing is applied lazily.
    pub(crate) unsliced_validity: Validity,
    // Total rows before any slicing.
    unsliced_n_rows: usize,
    stats_set: ArrayStats,
    // Lazily-applied slice bounds: start is inclusive, stop is exclusive
    // (logical length is `slice_stop - slice_start`).
    slice_start: usize,
    slice_stop: usize,
}
317
/// The parts of a [`ZstdArray`] returned by [`ZstdArray::into_parts`].
#[derive(Debug)]
pub struct ZstdArrayParts {
    /// The optional dictionary used for compression.
    pub dictionary: Option<ByteBuffer>,
    /// The compressed frames.
    pub frames: Vec<ByteBuffer>,
    /// The compression metadata.
    pub metadata: ZstdMetadata,
    /// The data type of the uncompressed array.
    pub dtype: DType,
    /// The validity of the uncompressed array.
    pub validity: Validity,
    /// The number of rows in the uncompressed array.
    pub n_rows: usize,
    /// Slice start offset (inclusive) into the uncompressed rows.
    pub slice_start: usize,
    /// Slice stop offset (exclusive) into the uncompressed rows.
    pub slice_stop: usize,
}
338
/// Intermediate result of compression: the trained dictionary (if any), the
/// compressed frames, and the per-frame metadata describing them.
struct Frames {
    dictionary: Option<ByteBuffer>,
    frames: Vec<ByteBuffer>,
    frame_metas: Vec<ZstdFrameMetadata>,
}
344
/// Pick an upper bound for the size of a trained Zstd dictionary.
///
/// Follows the recommendation in
/// https://github.com/facebook/zstd/blob/v1.5.5/lib/zdict.h#L190 of roughly
/// 1/100th of the sample data, capped at 100kB. Zstd also appears unable to
/// train dictionaries smaller than 256 bytes, hence the lower bound.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const MIN_DICT_SIZE: usize = 256;
    const MAX_DICT_SIZE: usize = 100 * 1024;
    (uncompressed_size / 100)
        .max(MIN_DICT_SIZE)
        .min(MAX_DICT_SIZE)
}
352
353fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
354    let mask = parray.validity_mask()?;
355    Ok(parray.clone().into_array().filter(mask)?.to_primitive())
356}
357
/// Concatenate all valid (non-null) values of `vbv` into one buffer, each
/// value prefixed by its byte length as a little-endian [`ViewLen`].
///
/// Returns the packed buffer together with the byte offset at which each
/// value's length prefix begins, so callers can split the stream on value
/// boundaries.
fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
    let mask = vbv.validity_mask()?;
    let buffer_and_value_byte_indices = match mask.bit_buffer() {
        // Every row is null: nothing to pack.
        AllOr::None => (Buffer::empty(), Vec::new()),
        _ => {
            // Reserve room for all value bytes plus one length prefix per
            // valid value, to avoid reallocation while packing.
            let mut buffer = BufferMut::with_capacity(
                usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
                    + mask.true_count() * size_of::<ViewLen>(),
            );
            let mut value_byte_indices = Vec::new();
            vbv.with_iterator(|iterator| {
                // by flattening, we should omit nulls
                for value in iterator.flatten() {
                    value_byte_indices.push(buffer.len());
                    // here's where we write the string lengths
                    buffer
                        .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
                    buffer.extend_from_slice(value);
                }
                Ok::<_, VortexError>(())
            })?;
            (buffer.freeze(), value_byte_indices)
        }
    };
    Ok(buffer_and_value_byte_indices)
}
384
/// Reconstruct BinaryView structs from length-prefixed byte data.
///
/// The buffer contains interleaved u32 lengths (little-endian) and string data.
/// When the cumulative data exceeds `max_buffer_len`, the buffer is split (zero-copy) into
/// multiple segments so that BinaryView's u32 offsets can address all data.
///
/// Pass [`MAX_BUFFER_LEN`] for `max_buffer_len` in production; a smaller value can be used in
/// tests to exercise the splitting path without allocating >2 GiB.
pub fn reconstruct_views(
    buffer: &ByteBuffer,
    max_buffer_len: usize,
) -> (Vec<ByteBuffer>, Buffer<BinaryView>) {
    let mut views = BufferMut::<BinaryView>::empty();
    let mut buffers = Vec::new();
    // Byte offset within `buffer` at which the current segment begins.
    let mut segment_start: usize = 0;
    let mut offset = 0;

    while offset < buffer.len() {
        // Read the little-endian length prefix preceding each value.
        let str_len = ViewLen::from_le_bytes(
            buffer
                .get(offset..offset + size_of::<ViewLen>())
                .vortex_expect("corrupted zstd length")
                .try_into()
                .ok()
                .vortex_expect("must fit ViewLen size"),
        ) as usize;

        let value_data_offset = offset + size_of::<ViewLen>();
        let local_offset = value_data_offset - segment_start;

        // If this value would push the segment past `max_buffer_len`, close
        // the current segment with a zero-copy slice and start a new one at
        // this value's length prefix. The `offset > segment_start` guard
        // avoids emitting an empty segment.
        if local_offset + str_len > max_buffer_len && offset > segment_start {
            buffers.push(buffer.slice(segment_start..offset));
            segment_start = offset;
        }

        // Recompute the offset relative to the (possibly new) segment.
        let local_offset = u32::try_from(value_data_offset - segment_start)
            .vortex_expect("local offset within segment must fit in u32");
        let buf_index = u32::try_from(buffers.len()).vortex_expect("buffer index must fit in u32");
        let value = &buffer[value_data_offset..value_data_offset + str_len];
        views.push(BinaryView::make_view(value, buf_index, local_offset));
        offset = value_data_offset + str_len;
    }

    // Flush the final (possibly only) segment.
    if segment_start < buffer.len() {
        buffers.push(buffer.slice(segment_start..buffer.len()));
    }

    (buffers, views.freeze())
}
434
435impl ZstdArray {
    /// Construct a [`ZstdArray`] covering all `n_rows` rows (unsliced).
    ///
    /// `frames` holds the compressed data described by `metadata`;
    /// `dictionary`, when present, is the Zstd dictionary shared by all
    /// frames. `validity` covers the full, unsliced row range.
    pub fn new(
        dictionary: Option<ByteBuffer>,
        frames: Vec<ByteBuffer>,
        dtype: DType,
        metadata: ZstdMetadata,
        n_rows: usize,
        validity: Validity,
    ) -> Self {
        Self {
            dictionary,
            frames,
            metadata,
            dtype,
            unsliced_validity: validity,
            unsliced_n_rows: n_rows,
            stats_set: Default::default(),
            // A fresh array is "sliced" over its entire row range.
            slice_start: 0,
            slice_stop: n_rows,
        }
    }
456
    /// Compress `value_bytes` into one Zstd frame per entry in
    /// `frame_byte_starts` (each start marks the first byte of its frame; the
    /// last frame runs to the end of the buffer).
    ///
    /// When `use_dictionary` is set and there are at least
    /// [`MIN_SAMPLES_FOR_DICTIONARY`] frames, a shared dictionary is trained
    /// from the frame contents first. `values_per_frame` and `n_values` are
    /// used only to record per-frame value counts in the returned metadata.
    fn compress_values(
        value_bytes: &ByteBuffer,
        frame_byte_starts: &[usize],
        level: i32,
        values_per_frame: usize,
        n_values: usize,
        use_dictionary: bool,
    ) -> VortexResult<Frames> {
        let n_frames = frame_byte_starts.len();

        // Would-be sample sizes if we end up applying zstd dictionary
        let mut sample_sizes = Vec::with_capacity(n_frames);
        for i in 0..n_frames {
            // The last frame extends to the end of the buffer.
            let frame_byte_end = frame_byte_starts
                .get(i + 1)
                .copied()
                .unwrap_or(value_bytes.len());
            sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
        }
        debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());

        let (dictionary, mut compressor) = if !use_dictionary
            || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY
        {
            // no dictionary
            (None, zstd::bulk::Compressor::new(level)?)
        } else {
            // with dictionary
            let max_dict_size = choose_max_dict_size(value_bytes.len());
            let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
                .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;

            let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
            (Some(ByteBuffer::from(dict)), compressor)
        };

        let mut frame_metas = vec![];
        let mut frames = vec![];
        for i in 0..n_frames {
            let frame_byte_end = frame_byte_starts
                .get(i + 1)
                .copied()
                .unwrap_or(value_bytes.len());

            let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
            let compressed = compressor
                .compress(uncompressed)
                .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
            frame_metas.push(ZstdFrameMetadata {
                uncompressed_size: uncompressed.len() as u64,
                // The final frame may hold fewer than `values_per_frame` values.
                n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
            });
            frames.push(ByteBuffer::from(compressed));
        }

        Ok(Frames {
            dictionary,
            frames,
            frame_metas,
        })
    }
518
    /// Creates a ZstdArray from a primitive array.
    ///
    /// Only the valid (non-null) elements are compressed; nulls are tracked
    /// via the validity alone.
    ///
    /// # Arguments
    /// * `parray` - The primitive array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_primitive(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, true)
    }
532
    /// Creates a ZstdArray from a primitive array without using a dictionary.
    ///
    /// This is useful when the compressed data will be decompressed by systems
    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
    ///
    /// Note: Without a dictionary, each frame is compressed independently.
    /// Dictionaries are trained from sample data from previously seen frames,
    /// to improve compression ratio.
    ///
    /// # Arguments
    /// * `parray` - The primitive array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_primitive_without_dict(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, false)
    }
553
554    fn from_primitive_impl(
555        parray: &PrimitiveArray,
556        level: i32,
557        values_per_frame: usize,
558        use_dictionary: bool,
559    ) -> VortexResult<Self> {
560        let dtype = parray.dtype().clone();
561        let byte_width = parray.ptype().byte_width();
562
563        // We compress only the valid elements.
564        let values = collect_valid_primitive(parray)?;
565        let n_values = values.len();
566        let values_per_frame = if values_per_frame > 0 {
567            values_per_frame
568        } else {
569            n_values
570        };
571
572        let value_bytes = values.buffer_handle().try_to_host_sync()?;
573        // Align frames to buffer alignment. This is necessary for overaligned buffers.
574        let alignment = *value_bytes.alignment();
575        let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment;
576
577        let frame_byte_starts = (0..n_values * byte_width)
578            .step_by(step_width)
579            .collect::<Vec<_>>();
580        let Frames {
581            dictionary,
582            frames,
583            frame_metas,
584        } = Self::compress_values(
585            &value_bytes,
586            &frame_byte_starts,
587            level,
588            values_per_frame,
589            n_values,
590            use_dictionary,
591        )?;
592
593        let metadata = ZstdMetadata {
594            dictionary_size: dictionary
595                .as_ref()
596                .map_or(0, |dict| dict.len())
597                .try_into()?,
598            frames: frame_metas,
599        };
600
601        Ok(ZstdArray::new(
602            dictionary,
603            frames,
604            dtype,
605            metadata,
606            parray.len(),
607            parray.validity().clone(),
608        ))
609    }
610
    /// Creates a ZstdArray from a VarBinView array.
    ///
    /// Only the valid (non-null) values are compressed, each stored with a
    /// u32 length prefix; nulls are tracked via the validity alone.
    ///
    /// # Arguments
    /// * `vbv` - The VarBinView array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_var_bin_view(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, true)
    }
624
    /// Creates a ZstdArray from a VarBinView array without using a dictionary.
    ///
    /// This is useful when the compressed data will be decompressed by systems
    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
    ///
    /// Note: Without a dictionary, each frame is compressed independently.
    /// Dictionaries are trained from sample data from previously seen frames,
    /// to improve compression ratio.
    ///
    /// # Arguments
    /// * `vbv` - The VarBinView array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_var_bin_view_without_dict(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, false)
    }
645
646    fn from_var_bin_view_impl(
647        vbv: &VarBinViewArray,
648        level: i32,
649        values_per_frame: usize,
650        use_dictionary: bool,
651    ) -> VortexResult<Self> {
652        // Approach for strings: we prefix each string with its length as a u32.
653        // This is the same as what Parquet does. In some cases it may be better
654        // to separate the binary data and lengths as two separate streams, but
655        // this approach is simpler and can be best in cases when there is
656        // mutual information between strings and their lengths.
657        let dtype = vbv.dtype().clone();
658
659        // We compress only the valid elements.
660        let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
661        let n_values = value_byte_indices.len();
662        let values_per_frame = if values_per_frame > 0 {
663            values_per_frame
664        } else {
665            n_values
666        };
667
668        let frame_byte_starts = (0..n_values)
669            .step_by(values_per_frame)
670            .map(|i| value_byte_indices[i])
671            .collect::<Vec<_>>();
672        let Frames {
673            dictionary,
674            frames,
675            frame_metas,
676        } = Self::compress_values(
677            &value_bytes,
678            &frame_byte_starts,
679            level,
680            values_per_frame,
681            n_values,
682            use_dictionary,
683        )?;
684
685        let metadata = ZstdMetadata {
686            dictionary_size: dictionary
687                .as_ref()
688                .map_or(0, |dict| dict.len())
689                .try_into()?,
690            frames: frame_metas,
691        };
692        Ok(ZstdArray::new(
693            dictionary,
694            frames,
695            dtype,
696            metadata,
697            vbv.len(),
698            vbv.validity().clone(),
699        ))
700    }
701
702    pub fn from_canonical(
703        canonical: &Canonical,
704        level: i32,
705        values_per_frame: usize,
706    ) -> VortexResult<Option<Self>> {
707        match canonical {
708            Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
709                parray,
710                level,
711                values_per_frame,
712            )?)),
713            Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
714                vbv,
715                level,
716                values_per_frame,
717            )?)),
718            _ => Ok(None),
719        }
720    }
721
    /// Canonicalizes `array` and compresses it with Zstd.
    ///
    /// # Errors
    /// Fails if the canonical form is neither primitive nor VarBinView (see
    /// [`Self::from_canonical`]), or if canonicalization/compression fails.
    pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
        Self::from_canonical(&array.to_canonical()?, level, values_per_frame)?
            .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
    }
726
727    fn byte_width(&self) -> usize {
728        if self.dtype.is_primitive() {
729            self.dtype.as_ptype().byte_width()
730        } else {
731            1
732        }
733    }
734
735    pub fn decompress(&self, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
736        // To start, we figure out which frames we need to decompress, and with
737        // what row offset into the first such frame.
738        let byte_width = self.byte_width();
739        let slice_n_rows = self.slice_stop - self.slice_start;
740        let slice_value_indices = self
741            .unsliced_validity
742            .execute_mask(self.unsliced_n_rows, ctx)?
743            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
744
745        let slice_value_idx_start = slice_value_indices[0];
746        let slice_value_idx_stop = slice_value_indices[1];
747
748        let mut frames_to_decompress = vec![];
749        let mut value_idx_start = 0;
750        let mut uncompressed_size_to_decompress = 0;
751        let mut n_skipped_values = 0;
752        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
753            if value_idx_start >= slice_value_idx_stop {
754                break;
755            }
756
757            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
758                .vortex_expect("Uncompressed size must fit in usize");
759            let frame_n_values = if frame_meta.n_values == 0 {
760                // possibly older primitive-only metadata that just didn't store this
761                frame_uncompressed_size / byte_width
762            } else {
763                usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
764            };
765
766            let value_idx_stop = value_idx_start + frame_n_values;
767            if value_idx_stop > slice_value_idx_start {
768                // we need this frame
769                frames_to_decompress.push(frame);
770                uncompressed_size_to_decompress += frame_uncompressed_size;
771            } else {
772                n_skipped_values += frame_n_values;
773            }
774            value_idx_start = value_idx_stop;
775        }
776
777        // then we actually decompress those frames
778        let mut decompressor = if let Some(dictionary) = &self.dictionary {
779            zstd::bulk::Decompressor::with_dictionary(dictionary)?
780        } else {
781            zstd::bulk::Decompressor::new()?
782        };
783        let mut decompressed = ByteBufferMut::with_capacity_aligned(
784            uncompressed_size_to_decompress,
785            Alignment::new(byte_width),
786        );
787        unsafe {
788            // safety: we immediately fill all bytes in the following loop,
789            // assuming our metadata's uncompressed size is correct
790            decompressed.set_len(uncompressed_size_to_decompress);
791        }
792        let mut uncompressed_start = 0;
793        for frame in frames_to_decompress {
794            let uncompressed_written = decompressor
795                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])?;
796            uncompressed_start += uncompressed_written;
797        }
798        if uncompressed_start != uncompressed_size_to_decompress {
799            vortex_panic!(
800                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
801                uncompressed_size_to_decompress,
802                uncompressed_start
803            );
804        }
805
806        let decompressed = decompressed.freeze();
807        // Last, we slice the exact values requested out of the decompressed data.
808        let mut slice_validity = self
809            .unsliced_validity
810            .slice(self.slice_start..self.slice_stop)?;
811
812        // NOTE: this block handles setting the output type when the validity and DType disagree.
813        //
814        // ZSTD is a compact block compressor, meaning that null values are not stored inline in
815        // the data frames. A ZSTD Array that was initialized must always hold onto its full
816        // validity bitmap, even if sliced to only include non-null values.
817        //
818        // We ensure that the validity of the decompressed array ALWAYS matches the validity
819        // implied by the DType.
820        if !self.dtype().is_nullable() && !matches!(slice_validity, Validity::NonNullable) {
821            assert!(
822                matches!(slice_validity, Validity::AllValid),
823                "ZSTD array expects to be non-nullable but there are nulls after decompression"
824            );
825
826            slice_validity = Validity::NonNullable;
827        } else if self.dtype.is_nullable() && matches!(slice_validity, Validity::NonNullable) {
828            slice_validity = Validity::AllValid;
829        }
830        //
831        // END OF IMPORTANT BLOCK
832        //
833
834        match &self.dtype {
835            DType::Primitive(..) => {
836                let slice_values_buffer = decompressed.slice(
837                    (slice_value_idx_start - n_skipped_values) * byte_width
838                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
839                );
840                let primitive = PrimitiveArray::from_values_byte_buffer(
841                    slice_values_buffer,
842                    self.dtype.as_ptype(),
843                    slice_validity,
844                    slice_n_rows,
845                );
846
847                Ok(primitive.into_array())
848            }
849            DType::Binary(_) | DType::Utf8(_) => {
850                match slice_validity.execute_mask(slice_n_rows, ctx)?.indices() {
851                    AllOr::All => {
852                        let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
853                        let valid_views = all_views.slice(
854                            slice_value_idx_start - n_skipped_values
855                                ..slice_value_idx_stop - n_skipped_values,
856                        );
857
858                        // SAFETY: we properly construct the views inside `reconstruct_views`
859                        Ok(unsafe {
860                            VarBinViewArray::new_unchecked(
861                                valid_views,
862                                Arc::from(buffers),
863                                self.dtype.clone(),
864                                slice_validity,
865                            )
866                        }
867                        .into_array())
868                    }
869                    AllOr::None => Ok(ConstantArray::new(
870                        Scalar::null(self.dtype.clone()),
871                        slice_n_rows,
872                    )
873                    .into_array()),
874                    AllOr::Some(valid_indices) => {
875                        let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
876                        let valid_views = all_views.slice(
877                            slice_value_idx_start - n_skipped_values
878                                ..slice_value_idx_stop - n_skipped_values,
879                        );
880
881                        let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
882                        for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
883                            views[*index] = view
884                        }
885
886                        // SAFETY: we properly construct the views inside `reconstruct_views`
887                        Ok(unsafe {
888                            VarBinViewArray::new_unchecked(
889                                views.freeze(),
890                                Arc::from(buffers),
891                                self.dtype.clone(),
892                                slice_validity,
893                            )
894                        }
895                        .into_array())
896                    }
897                }
898            }
899            _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype),
900        }
901    }
902
903    pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
904        let new_start = self.slice_start + start;
905        let new_stop = self.slice_start + stop;
906
907        assert!(
908            new_start <= self.slice_stop,
909            "new slice start {new_start} exceeds end {}",
910            self.slice_stop
911        );
912
913        assert!(
914            new_stop <= self.slice_stop,
915            "new slice stop {new_stop} exceeds end {}",
916            self.slice_stop
917        );
918
919        ZstdArray {
920            slice_start: self.slice_start + start,
921            slice_stop: self.slice_start + stop,
922            stats_set: Default::default(),
923            ..self.clone()
924        }
925    }
926
927    /// Consumes the array and returns its parts.
928    pub fn into_parts(self) -> ZstdArrayParts {
929        ZstdArrayParts {
930            dictionary: self.dictionary,
931            frames: self.frames,
932            metadata: self.metadata,
933            dtype: self.dtype,
934            validity: self.unsliced_validity,
935            n_rows: self.unsliced_n_rows,
936            slice_start: self.slice_start,
937            slice_stop: self.slice_stop,
938        }
939    }
940
    /// Returns the logical [`DType`] of the values this array decompresses to.
    pub(crate) fn dtype(&self) -> &DType {
        &self.dtype
    }
944
    /// Start of the current slice, in unsliced row coordinates.
    pub(crate) fn slice_start(&self) -> usize {
        self.slice_start
    }
948
    /// End (exclusive) of the current slice, in unsliced row coordinates.
    pub(crate) fn slice_stop(&self) -> usize {
        self.slice_stop
    }
952
    /// Total number of rows in the underlying (unsliced) compressed data.
    pub(crate) fn unsliced_n_rows(&self) -> usize {
        self.unsliced_n_rows
    }
956}
957
impl ValiditySliceHelper for ZstdArray {
    /// Exposes the full unsliced validity plus the current slice bounds so the shared
    /// validity helpers can derive this array's effective (sliced) validity.
    fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
        (&self.unsliced_validity, self.slice_start, self.slice_stop)
    }
}
963
964impl OperationsVTable<Zstd> for Zstd {
965    fn scalar_at(array: &ZstdArray, index: usize) -> VortexResult<Scalar> {
966        let mut ctx = LEGACY_SESSION.create_execution_ctx();
967        array
968            ._slice(index, index + 1)
969            .decompress(&mut ctx)?
970            .scalar_at(0)
971    }
972}
973
#[cfg(test)]
#[allow(clippy::cast_possible_truncation)]
mod tests {
    use vortex_buffer::ByteBuffer;

    use super::reconstruct_views;
    use crate::array::BinaryView;

    /// Build a Zstd-style interleaved buffer: [u32-LE length][string bytes] repeated.
    fn make_interleaved(strings: &[&[u8]]) -> ByteBuffer {
        let mut bytes = Vec::new();
        strings.iter().for_each(|s| {
            bytes.extend_from_slice(&(s.len() as u32).to_le_bytes());
            bytes.extend_from_slice(s);
        });
        ByteBuffer::copy_from(bytes.as_slice())
    }

    #[test]
    fn test_reconstruct_views_no_split() {
        let input: &[&[u8]] = &[b"hello", b"world"];
        let (buffers, views) = reconstruct_views(&make_interleaved(input), 1024);

        assert_eq!(buffers.len(), 1);
        assert_eq!(views.len(), 2);
        // Each entry: [u32 len (4 bytes)][data], so offsets are 4 and 4+5+4=13
        assert_eq!(views[0], BinaryView::make_view(b"hello", 0, 4));
        assert_eq!(views[1], BinaryView::make_view(b"world", 0, 13));
    }

    #[test]
    fn test_reconstruct_views_split_across_segments() {
        // "aaaaaaaaaaaaa" (13 bytes) and "bbbbbbbbbbbbb" (13 bytes).
        // Each entry occupies 4 (length prefix) + 13 (data) = 17 bytes.
        // With max_buffer_len=20, the second entry's data (offset 4+13+4=21) exceeds the limit,
        // so it rolls into a second segment.
        let input: &[&[u8]] = &[b"aaaaaaaaaaaaa", b"bbbbbbbbbbbbb"];
        let (buffers, views) = reconstruct_views(&make_interleaved(input), 20);

        assert_eq!(buffers.len(), 2);
        assert_eq!(views.len(), 2);
        assert_eq!(views[0], BinaryView::make_view(b"aaaaaaaaaaaaa", 0, 4));
        // Second entry starts a new segment at byte 17 (the length prefix), so local offset = 4.
        assert_eq!(views[1], BinaryView::make_view(b"bbbbbbbbbbbbb", 1, 4));
    }
}
1022}