Skip to main content

vortex_zstd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use itertools::Itertools as _;
12use prost::Message as _;
13use vortex_array::Array;
14use vortex_array::ArrayEq;
15use vortex_array::ArrayHash;
16use vortex_array::ArrayId;
17use vortex_array::ArrayParts;
18use vortex_array::ArrayRef;
19use vortex_array::ArrayView;
20use vortex_array::Canonical;
21use vortex_array::ExecutionCtx;
22use vortex_array::ExecutionResult;
23use vortex_array::IntoArray;
24use vortex_array::Precision;
25use vortex_array::accessor::ArrayAccessor;
26use vortex_array::arrays::ConstantArray;
27use vortex_array::arrays::PrimitiveArray;
28use vortex_array::arrays::VarBinViewArray;
29use vortex_array::arrays::varbinview::build_views::BinaryView;
30use vortex_array::arrays::varbinview::build_views::MAX_BUFFER_LEN;
31use vortex_array::buffer::BufferHandle;
32use vortex_array::dtype::DType;
33use vortex_array::scalar::Scalar;
34use vortex_array::serde::ArrayChildren;
35use vortex_array::validity::Validity;
36use vortex_array::vtable::OperationsVTable;
37use vortex_array::vtable::VTable;
38use vortex_array::vtable::ValidityVTable;
39use vortex_array::vtable::child_to_validity;
40use vortex_array::vtable::validity_to_child;
41use vortex_buffer::Alignment;
42use vortex_buffer::Buffer;
43use vortex_buffer::BufferMut;
44use vortex_buffer::ByteBuffer;
45use vortex_buffer::ByteBufferMut;
46use vortex_error::VortexError;
47use vortex_error::VortexExpect;
48use vortex_error::VortexResult;
49use vortex_error::vortex_bail;
50use vortex_error::vortex_ensure;
51use vortex_error::vortex_err;
52use vortex_error::vortex_panic;
53use vortex_mask::AllOr;
54use vortex_session::VortexSession;
55use vortex_session::registry::CachedId;
56
57use crate::ZstdFrameMetadata;
58use crate::ZstdMetadata;
59
/// Minimum number of samples (frames) required before a dictionary is trained.
///
/// Zstd doesn't support training dictionaries on very few samples.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
/// Integer type of the little-endian length prefix written in front of every
/// variable-length value before compression.
type ViewLen = u32;
63
64// Overall approach here:
65// Zstd can be used on the whole array (values_per_frame = 0), resulting in a single Zstd
66// frame, or it can be used with a dictionary (values_per_frame < # values), resulting in
67// multiple Zstd frames sharing a common dictionary. This latter case is helpful if you
68// want somewhat faster access to slices or individual rows, allowing us to only
69// decompress the necessary frames.
70
71// Visually, during decompression, we have an interval of frames we're
72// decompressing and a tighter interval of the slice we actually care about.
73// |=============values (all valid elements)==============|
74// |<-skipped_uncompressed->|----decompressed-------------|
75//                              |------slice-------|
76//                              ^                  ^
77// |<-slice_uncompressed_start->|                  |
78// |<------------slice_uncompressed_stop---------->|
79// We then insert these values to the correct position using a primitive array
80// constructor.
81
/// A [`Zstd`]-encoded Vortex array.
///
/// This is an alias for `Array<Zstd>`; instances are built with [`Zstd::try_new`] or one of
/// the `Zstd::from_*` compression constructors.
pub type ZstdArray = Array<Zstd>;
84
85impl ArrayHash for ZstdData {
86    fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
87        match &self.dictionary {
88            Some(dict) => {
89                true.hash(state);
90                dict.array_hash(state, precision);
91            }
92            None => {
93                false.hash(state);
94            }
95        }
96        for frame in &self.frames {
97            frame.array_hash(state, precision);
98        }
99        self.unsliced_n_rows.hash(state);
100        self.slice_start.hash(state);
101        self.slice_stop.hash(state);
102    }
103}
104
105impl ArrayEq for ZstdData {
106    fn array_eq(&self, other: &Self, precision: Precision) -> bool {
107        if !match (&self.dictionary, &other.dictionary) {
108            (Some(d1), Some(d2)) => d1.array_eq(d2, precision),
109            (None, None) => true,
110            _ => false,
111        } {
112            return false;
113        }
114        if self.frames.len() != other.frames.len() {
115            return false;
116        }
117        for (a, b) in self.frames.iter().zip(&other.frames) {
118            if !a.array_eq(b, precision) {
119                return false;
120            }
121        }
122        self.unsliced_n_rows == other.unsliced_n_rows
123            && self.slice_start == other.slice_start
124            && self.slice_stop == other.slice_stop
125    }
126}
127
128impl VTable for Zstd {
129    type ArrayData = ZstdData;
130
131    type OperationsVTable = Self;
132    type ValidityVTable = Self;
133
134    fn id(&self) -> ArrayId {
135        static ID: CachedId = CachedId::new("vortex.zstd");
136        *ID
137    }
138
139    fn validate(
140        &self,
141        data: &Self::ArrayData,
142        dtype: &DType,
143        len: usize,
144        slots: &[Option<ArrayRef>],
145    ) -> VortexResult<()> {
146        let validity = child_to_validity(slots[0].as_ref(), dtype.nullability());
147        data.validate(dtype, len, &validity)
148    }
149
150    fn nbuffers(array: ArrayView<'_, Self>) -> usize {
151        array.dictionary.is_some() as usize + array.frames.len()
152    }
153
154    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
155        if let Some(dict) = &array.dictionary {
156            if idx == 0 {
157                return BufferHandle::new_host(dict.clone());
158            }
159            BufferHandle::new_host(array.frames[idx - 1].clone())
160        } else {
161            BufferHandle::new_host(array.frames[idx].clone())
162        }
163    }
164
165    fn buffer_name(array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
166        if array.dictionary.is_some() {
167            if idx == 0 {
168                Some("dictionary".to_string())
169            } else {
170                Some(format!("frame_{}", idx - 1))
171            }
172        } else {
173            Some(format!("frame_{idx}"))
174        }
175    }
176
177    fn serialize(
178        array: ArrayView<'_, Self>,
179        _session: &VortexSession,
180    ) -> VortexResult<Option<Vec<u8>>> {
181        Ok(Some(array.metadata.clone().encode_to_vec()))
182    }
183
184    fn deserialize(
185        &self,
186        dtype: &DType,
187        len: usize,
188        metadata: &[u8],
189        buffers: &[BufferHandle],
190        children: &dyn ArrayChildren,
191        _session: &VortexSession,
192    ) -> VortexResult<ArrayParts<Self>> {
193        let metadata = ZstdMetadata::decode(metadata)?;
194        let validity = if children.is_empty() {
195            Validity::from(dtype.nullability())
196        } else if children.len() == 1 {
197            let validity = children.get(0, &Validity::DTYPE, len)?;
198            Validity::Array(validity)
199        } else {
200            vortex_bail!("ZstdArray expected 0 or 1 child, got {}", children.len());
201        };
202
203        let (dictionary_buffer, compressed_buffers) = if metadata.dictionary_size == 0 {
204            // no dictionary
205            (
206                None,
207                buffers
208                    .iter()
209                    .map(|b| b.clone().try_to_host_sync())
210                    .collect::<VortexResult<Vec<_>>>()?,
211            )
212        } else {
213            // with dictionary
214            (
215                Some(buffers[0].clone().try_to_host_sync()?),
216                buffers[1..]
217                    .iter()
218                    .map(|b| b.clone().try_to_host_sync())
219                    .collect::<VortexResult<Vec<_>>>()?,
220            )
221        };
222
223        let slots = vec![validity_to_child(&validity, len)];
224        let data = ZstdData::new(dictionary_buffer, compressed_buffers, metadata, len);
225        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
226    }
227
228    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
229        SLOT_NAMES[idx].to_string()
230    }
231
232    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
233        let unsliced_validity = child_to_validity(
234            array.as_ref().slots()[0].as_ref(),
235            array.dtype().nullability(),
236        );
237        array
238            .data()
239            .decompress(array.dtype(), &unsliced_validity, ctx)?
240            .execute::<ArrayRef>(ctx)
241            .map(ExecutionResult::done)
242    }
243
244    fn reduce_parent(
245        array: ArrayView<'_, Self>,
246        parent: &ArrayRef,
247        child_idx: usize,
248    ) -> VortexResult<Option<ArrayRef>> {
249        crate::rules::RULES.evaluate(array, parent, child_idx)
250    }
251}
252
253#[derive(Clone, Debug)]
254pub struct Zstd;
255
256impl Zstd {
257    pub fn try_new(dtype: DType, data: ZstdData, validity: Validity) -> VortexResult<ZstdArray> {
258        let len = data.len();
259        data.validate(&dtype, len, &validity)?;
260        let slots = vec![validity_to_child(&validity, data.unsliced_n_rows())];
261        Ok(unsafe {
262            Array::from_parts_unchecked(ArrayParts::new(Zstd, dtype, len, data).with_slots(slots))
263        })
264    }
265
266    /// Compress a [`VarBinViewArray`] using Zstd without a dictionary.
267    pub fn from_var_bin_view_without_dict(
268        vbv: &VarBinViewArray,
269        level: i32,
270        values_per_frame: usize,
271        ctx: &mut ExecutionCtx,
272    ) -> VortexResult<ZstdArray> {
273        let validity = vbv.validity()?;
274        Self::try_new(
275            vbv.dtype().clone(),
276            ZstdData::from_var_bin_view_without_dict(vbv, level, values_per_frame, ctx)?,
277            validity,
278        )
279    }
280
281    /// Compress a [`PrimitiveArray`] using Zstd.
282    pub fn from_primitive(
283        parray: &PrimitiveArray,
284        level: i32,
285        values_per_frame: usize,
286        ctx: &mut ExecutionCtx,
287    ) -> VortexResult<ZstdArray> {
288        let validity = parray.validity()?;
289        Self::try_new(
290            parray.dtype().clone(),
291            ZstdData::from_primitive(parray, level, values_per_frame, ctx)?,
292            validity,
293        )
294    }
295
296    /// Compress a [`VarBinViewArray`] using Zstd.
297    pub fn from_var_bin_view(
298        vbv: &VarBinViewArray,
299        level: i32,
300        values_per_frame: usize,
301        ctx: &mut ExecutionCtx,
302    ) -> VortexResult<ZstdArray> {
303        let validity = vbv.validity()?;
304        Self::try_new(
305            vbv.dtype().clone(),
306            ZstdData::from_var_bin_view(vbv, level, values_per_frame, ctx)?,
307            validity,
308        )
309    }
310
311    pub fn decompress(array: &ZstdArray, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
312        let unsliced_validity = child_to_validity(
313            array.as_ref().slots()[0].as_ref(),
314            array.dtype().nullability(),
315        );
316        array
317            .data()
318            .decompress(array.dtype(), &unsliced_validity, ctx)
319    }
320}
321
/// Number of child slots held by a Zstd array.
pub(super) const NUM_SLOTS: usize = 1;
/// Names of the child slots, indexed by slot position; slot 0 holds the validity child.
pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
325
/// Compressed payload and slicing state of a Zstd-encoded array.
#[derive(Clone, Debug)]
pub struct ZstdData {
    /// Zstd dictionary used by every frame, when one is present.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// Compressed frames, in value order.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Dictionary size and per-frame metadata; this is what gets serialized for the array.
    pub(crate) metadata: ZstdMetadata,
    /// Row count of the array before any slicing was applied.
    unsliced_n_rows: usize,
    /// First row (inclusive) of the current slice, relative to the unsliced array.
    slice_start: usize,
    /// End row (exclusive) of the current slice, relative to the unsliced array.
    slice_stop: usize,
}
335
336impl Display for ZstdData {
337    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
338        write!(
339            f,
340            "nrows: {}, slice: {}..{}",
341            self.unsliced_n_rows, self.slice_start, self.slice_stop
342        )
343    }
344}
345
/// Plain, public-field view of a Zstd array's components.
///
/// Mirrors the fields of [`ZstdData`] and additionally carries the array's validity.
/// NOTE(review): presumably the value produced by an `into_parts`-style method that is not
/// visible in this part of the file - confirm.
pub struct ZstdDataParts {
    /// Zstd dictionary, when present.
    pub dictionary: Option<ByteBuffer>,
    /// Compressed frames.
    pub frames: Vec<ByteBuffer>,
    /// Dictionary size and per-frame metadata.
    pub metadata: ZstdMetadata,
    /// Validity of the array.
    pub validity: Validity,
    /// Row count; presumably that of the unsliced array - TODO confirm.
    pub n_rows: usize,
    /// Start of the slice range.
    pub slice_start: usize,
    /// End of the slice range.
    pub slice_stop: usize,
}
355
/// Result of `ZstdData::compress_values`: the optional trained dictionary, the compressed
/// frames, and one [`ZstdFrameMetadata`] entry per frame.
#[derive(Debug)]
struct Frames {
    dictionary: Option<ByteBuffer>,
    frames: Vec<ByteBuffer>,
    frame_metas: Vec<ZstdFrameMetadata>,
}
363
/// Picks the upper bound on the size of a trained dictionary for `uncompressed_size` bytes
/// of input.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    // It appears that zstd can't train dictionaries with <256 bytes.
    const MIN_DICT_SIZE: usize = 256;
    // Recommendation from
    // https://github.com/facebook/zstd/blob/v1.5.5/lib/zdict.h#L190
    // is 1/100 of the data size, up to 100kB.
    const MAX_DICT_SIZE: usize = 100 * 1024;

    let one_percent = uncompressed_size / 100;
    one_percent.max(MIN_DICT_SIZE).min(MAX_DICT_SIZE)
}
371
372fn collect_valid_primitive(
373    parray: &PrimitiveArray,
374    ctx: &mut ExecutionCtx,
375) -> VortexResult<PrimitiveArray> {
376    let mask = parray
377        .as_ref()
378        .validity()?
379        .execute_mask(parray.as_ref().len(), ctx)?;
380    let result = parray.filter(mask)?.execute::<PrimitiveArray>(ctx)?;
381    Ok(result)
382}
383
384fn collect_valid_vbv(
385    vbv: &VarBinViewArray,
386    ctx: &mut ExecutionCtx,
387) -> VortexResult<(ByteBuffer, Vec<usize>)> {
388    let mask = vbv
389        .as_ref()
390        .validity()?
391        .execute_mask(vbv.as_ref().len(), ctx)?;
392    let buffer_and_value_byte_indices = match mask.bit_buffer() {
393        AllOr::None => (Buffer::empty(), Vec::new()),
394        _ => {
395            let mut buffer = BufferMut::with_capacity(
396                usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
397                    + mask.true_count() * size_of::<ViewLen>(),
398            );
399            let mut value_byte_indices = Vec::new();
400            vbv.with_iterator(|iterator| {
401                // by flattening, we should omit nulls
402                for value in iterator.flatten() {
403                    value_byte_indices.push(buffer.len());
404                    // here's where we write the string lengths
405                    buffer
406                        .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
407                    buffer.extend_from_slice(value);
408                }
409                Ok::<_, VortexError>(())
410            })?;
411            (buffer.freeze(), value_byte_indices)
412        }
413    };
414    Ok(buffer_and_value_byte_indices)
415}
416
417/// Reconstruct BinaryView structs from length-prefixed byte data.
418///
419/// The buffer contains interleaved u32 lengths (little-endian) and string data.
420/// When the cumulative data exceeds `max_buffer_len`, the buffer is split (zero-copy) into
421/// multiple segments so that BinaryView's u32 offsets can address all data.
422///
423/// Pass [`MAX_BUFFER_LEN`] for `max_buffer_len` in production; a smaller value can be used in
424/// tests to exercise the splitting path without allocating >2 GiB.
425pub fn reconstruct_views(
426    buffer: &ByteBuffer,
427    max_buffer_len: usize,
428) -> (Vec<ByteBuffer>, Buffer<BinaryView>) {
429    let mut views = BufferMut::<BinaryView>::empty();
430    let mut buffers = Vec::new();
431    let mut segment_start: usize = 0;
432    let mut offset = 0;
433
434    while offset < buffer.len() {
435        let str_len = ViewLen::from_le_bytes(
436            buffer
437                .get(offset..offset + size_of::<ViewLen>())
438                .vortex_expect("corrupted zstd length")
439                .try_into()
440                .ok()
441                .vortex_expect("must fit ViewLen size"),
442        ) as usize;
443
444        let value_data_offset = offset + size_of::<ViewLen>();
445        let local_offset = value_data_offset - segment_start;
446
447        if local_offset + str_len > max_buffer_len && offset > segment_start {
448            buffers.push(buffer.slice(segment_start..offset));
449            segment_start = offset;
450        }
451
452        let local_offset = u32::try_from(value_data_offset - segment_start)
453            .vortex_expect("local offset within segment must fit in u32");
454        let buf_index = u32::try_from(buffers.len()).vortex_expect("buffer index must fit in u32");
455        let value = &buffer[value_data_offset..value_data_offset + str_len];
456        views.push(BinaryView::make_view(value, buf_index, local_offset));
457        offset = value_data_offset + str_len;
458    }
459
460    if segment_start < buffer.len() {
461        buffers.push(buffer.slice(segment_start..buffer.len()));
462    }
463
464    (buffers, views.freeze())
465}
466
467impl ZstdData {
468    pub fn new(
469        dictionary: Option<ByteBuffer>,
470        frames: Vec<ByteBuffer>,
471        metadata: ZstdMetadata,
472        n_rows: usize,
473    ) -> Self {
474        Self {
475            dictionary,
476            frames,
477            metadata,
478            unsliced_n_rows: n_rows,
479            slice_start: 0,
480            slice_stop: n_rows,
481        }
482    }
483
484    pub fn validate(&self, dtype: &DType, len: usize, validity: &Validity) -> VortexResult<()> {
485        vortex_ensure!(
486            matches!(
487                dtype,
488                DType::Primitive(..) | DType::Binary(_) | DType::Utf8(_)
489            ),
490            "Unsupported dtype for Zstd array: {dtype}"
491        );
492        vortex_ensure!(
493            self.slice_start <= self.slice_stop,
494            "Invalid slice range {}..{}",
495            self.slice_start,
496            self.slice_stop
497        );
498        vortex_ensure!(
499            self.slice_stop <= self.unsliced_n_rows,
500            "Slice stop {} exceeds unsliced row count {}",
501            self.slice_stop,
502            self.unsliced_n_rows
503        );
504        vortex_ensure!(
505            self.slice_stop - self.slice_start == len,
506            "Slice length {} does not match array length {}",
507            self.slice_stop - self.slice_start,
508            len
509        );
510        if let Some(validity_len) = validity.maybe_len() {
511            vortex_ensure!(
512                validity_len == self.unsliced_n_rows,
513                "Validity length {} does not match unsliced row count {}",
514                validity_len,
515                self.unsliced_n_rows
516            );
517        }
518
519        match &self.dictionary {
520            Some(dictionary) => vortex_ensure!(
521                usize::try_from(self.metadata.dictionary_size)? == dictionary.len(),
522                "Dictionary size metadata {} does not match buffer size {}",
523                self.metadata.dictionary_size,
524                dictionary.len()
525            ),
526            None => vortex_ensure!(
527                self.metadata.dictionary_size == 0,
528                "Dictionary metadata present without dictionary buffer"
529            ),
530        }
531        vortex_ensure!(
532            self.frames.len() == self.metadata.frames.len(),
533            "Frame count {} does not match metadata frame count {}",
534            self.frames.len(),
535            self.metadata.frames.len()
536        );
537
538        Ok(())
539    }
540
541    pub(crate) fn with_slice(&self, start: usize, stop: usize) -> Self {
542        let new_start = self.slice_start + start;
543        let new_stop = self.slice_start + stop;
544
545        assert!(
546            new_start <= self.slice_stop,
547            "new slice start {new_start} exceeds end {}",
548            self.slice_stop
549        );
550
551        assert!(
552            new_stop <= self.slice_stop,
553            "new slice stop {new_stop} exceeds end {}",
554            self.slice_stop
555        );
556
557        Self {
558            slice_start: new_start,
559            slice_stop: new_stop,
560            ..self.clone()
561        }
562    }
563
564    fn compress_values(
565        value_bytes: &ByteBuffer,
566        frame_byte_starts: &[usize],
567        level: i32,
568        values_per_frame: usize,
569        n_values: usize,
570        use_dictionary: bool,
571    ) -> VortexResult<Frames> {
572        let n_frames = frame_byte_starts.len();
573
574        // Would-be sample sizes if we end up applying zstd dictionary
575        let mut sample_sizes = Vec::with_capacity(n_frames);
576        for i in 0..n_frames {
577            let frame_byte_end = frame_byte_starts
578                .get(i + 1)
579                .copied()
580                .unwrap_or(value_bytes.len());
581            sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
582        }
583        debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
584
585        let (dictionary, mut compressor) = if !use_dictionary
586            || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY
587        {
588            // no dictionary
589            (None, zstd::bulk::Compressor::new(level)?)
590        } else {
591            // with dictionary
592            let max_dict_size = choose_max_dict_size(value_bytes.len());
593            let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
594                .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
595
596            let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
597            (Some(ByteBuffer::from(dict)), compressor)
598        };
599
600        let mut frame_metas = vec![];
601        let mut frames = vec![];
602        for i in 0..n_frames {
603            let frame_byte_end = frame_byte_starts
604                .get(i + 1)
605                .copied()
606                .unwrap_or(value_bytes.len());
607
608            let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
609            let compressed = compressor
610                .compress(uncompressed)
611                .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
612            frame_metas.push(ZstdFrameMetadata {
613                uncompressed_size: uncompressed.len() as u64,
614                n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
615            });
616            frames.push(ByteBuffer::from(compressed));
617        }
618
619        Ok(Frames {
620            dictionary,
621            frames,
622            frame_metas,
623        })
624    }
625
    /// Compresses a primitive array into [`ZstdData`], training a dictionary when possible.
    ///
    /// Only the valid (non-null) elements are compressed. A dictionary is trained only when
    /// the values end up split into at least [`MIN_SAMPLES_FOR_DICTIONARY`] frames.
    ///
    /// # Arguments
    /// * `parray` - The primitive array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    /// * `ctx` - Execution context used to evaluate the validity and filter out nulls
    pub fn from_primitive(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, true, ctx)
    }
640
    /// Compresses a primitive array into [`ZstdData`] without using a dictionary.
    ///
    /// This is useful when the compressed data will be decompressed by systems
    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
    ///
    /// Note: Without a dictionary, each frame is compressed independently. The
    /// dictionary-enabled variant instead trains one shared dictionary, using the frames as
    /// samples, to improve compression ratio.
    ///
    /// # Arguments
    /// * `parray` - The primitive array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    /// * `ctx` - Execution context used to evaluate the validity and filter out nulls
    pub fn from_primitive_without_dict(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, false, ctx)
    }
662
663    fn from_primitive_impl(
664        parray: &PrimitiveArray,
665        level: i32,
666        values_per_frame: usize,
667        use_dictionary: bool,
668        ctx: &mut ExecutionCtx,
669    ) -> VortexResult<Self> {
670        let byte_width = parray.ptype().byte_width();
671
672        // We compress only the valid elements.
673        let values = collect_valid_primitive(parray, ctx)?;
674        let n_values = values.len();
675        let values_per_frame = if values_per_frame > 0 {
676            values_per_frame
677        } else {
678            n_values
679        };
680
681        let value_bytes = values.buffer_handle().try_to_host_sync()?;
682        // Align frames to buffer alignment. This is necessary for overaligned buffers.
683        let alignment = *value_bytes.alignment();
684        let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment;
685
686        let frame_byte_starts = (0..n_values * byte_width)
687            .step_by(step_width)
688            .collect::<Vec<_>>();
689        let Frames {
690            dictionary,
691            frames,
692            frame_metas,
693        } = Self::compress_values(
694            &value_bytes,
695            &frame_byte_starts,
696            level,
697            values_per_frame,
698            n_values,
699            use_dictionary,
700        )?;
701
702        let metadata = ZstdMetadata {
703            dictionary_size: dictionary
704                .as_ref()
705                .map_or(0, |dict| dict.len())
706                .try_into()?,
707            frames: frame_metas,
708        };
709
710        Ok(ZstdData::new(dictionary, frames, metadata, parray.len()))
711    }
712
    /// Compresses a VarBinView array into [`ZstdData`], training a dictionary when possible.
    ///
    /// Only the valid (non-null) elements are compressed. A dictionary is trained only when
    /// the values end up split into at least [`MIN_SAMPLES_FOR_DICTIONARY`] frames.
    ///
    /// # Arguments
    /// * `vbv` - The VarBinView array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    /// * `ctx` - Execution context used to evaluate the validity
    pub fn from_var_bin_view(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, true, ctx)
    }
727
    /// Compresses a VarBinView array into [`ZstdData`] without using a dictionary.
    ///
    /// This is useful when the compressed data will be decompressed by systems
    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
    ///
    /// Note: Without a dictionary, each frame is compressed independently. The
    /// dictionary-enabled variant instead trains one shared dictionary, using the frames as
    /// samples, to improve compression ratio.
    ///
    /// # Arguments
    /// * `vbv` - The VarBinView array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    /// * `ctx` - Execution context used to evaluate the validity
    pub fn from_var_bin_view_without_dict(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, false, ctx)
    }
749
750    fn from_var_bin_view_impl(
751        vbv: &VarBinViewArray,
752        level: i32,
753        values_per_frame: usize,
754        use_dictionary: bool,
755        ctx: &mut ExecutionCtx,
756    ) -> VortexResult<Self> {
757        // Approach for strings: we prefix each string with its length as a u32.
758        // This is the same as what Parquet does. In some cases it may be better
759        // to separate the binary data and lengths as two separate streams, but
760        // this approach is simpler and can be best in cases when there is
761        // mutual information between strings and their lengths.
762        // We compress only the valid elements.
763        let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv, ctx)?;
764        let n_values = value_byte_indices.len();
765        let values_per_frame = if values_per_frame > 0 {
766            values_per_frame
767        } else {
768            n_values
769        };
770
771        let frame_byte_starts = (0..n_values)
772            .step_by(values_per_frame)
773            .map(|i| value_byte_indices[i])
774            .collect::<Vec<_>>();
775        let Frames {
776            dictionary,
777            frames,
778            frame_metas,
779        } = Self::compress_values(
780            &value_bytes,
781            &frame_byte_starts,
782            level,
783            values_per_frame,
784            n_values,
785            use_dictionary,
786        )?;
787
788        let metadata = ZstdMetadata {
789            dictionary_size: dictionary
790                .as_ref()
791                .map_or(0, |dict| dict.len())
792                .try_into()?,
793            frames: frame_metas,
794        };
795        Ok(ZstdData::new(dictionary, frames, metadata, vbv.len()))
796    }
797
798    pub fn from_canonical(
799        canonical: &Canonical,
800        level: i32,
801        values_per_frame: usize,
802        ctx: &mut ExecutionCtx,
803    ) -> VortexResult<Option<Self>> {
804        match canonical {
805            Canonical::Primitive(parray) => Ok(Some(ZstdData::from_primitive(
806                parray,
807                level,
808                values_per_frame,
809                ctx,
810            )?)),
811            Canonical::VarBinView(vbv) => Ok(Some(ZstdData::from_var_bin_view(
812                vbv,
813                level,
814                values_per_frame,
815                ctx,
816            )?)),
817            _ => Ok(None),
818        }
819    }
820
821    pub fn from_array(
822        array: ArrayRef,
823        level: i32,
824        values_per_frame: usize,
825        ctx: &mut ExecutionCtx,
826    ) -> VortexResult<Self> {
827        let canonical = array.execute::<Canonical>(ctx)?;
828        Self::from_canonical(&canonical, level, values_per_frame, ctx)?
829            .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
830    }
831
832    fn byte_width(dtype: &DType) -> usize {
833        if dtype.is_primitive() {
834            dtype.as_ptype().byte_width()
835        } else {
836            1
837        }
838    }
839
840    fn decompress(
841        &self,
842        dtype: &DType,
843        unsliced_validity: &Validity,
844        ctx: &mut ExecutionCtx,
845    ) -> VortexResult<ArrayRef> {
846        // To start, we figure out which frames we need to decompress, and with
847        // what row offset into the first such frame.
848        let byte_width = Self::byte_width(dtype);
849        let slice_n_rows = self.slice_stop - self.slice_start;
850        let slice_value_indices = unsliced_validity
851            .execute_mask(self.unsliced_n_rows, ctx)?
852            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
853
854        let slice_value_idx_start = slice_value_indices[0];
855        let slice_value_idx_stop = slice_value_indices[1];
856
857        let mut frames_to_decompress = vec![];
858        let mut value_idx_start = 0;
859        let mut uncompressed_size_to_decompress = 0;
860        let mut n_skipped_values = 0;
861        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
862            if value_idx_start >= slice_value_idx_stop {
863                break;
864            }
865
866            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
867                .vortex_expect("Uncompressed size must fit in usize");
868            let frame_n_values = if frame_meta.n_values == 0 {
869                // possibly older primitive-only metadata that just didn't store this
870                frame_uncompressed_size / byte_width
871            } else {
872                usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
873            };
874
875            let value_idx_stop = value_idx_start + frame_n_values;
876            if value_idx_stop > slice_value_idx_start {
877                // we need this frame
878                frames_to_decompress.push(frame);
879                uncompressed_size_to_decompress += frame_uncompressed_size;
880            } else {
881                n_skipped_values += frame_n_values;
882            }
883            value_idx_start = value_idx_stop;
884        }
885
886        // then we actually decompress those frames
887        let mut decompressor = if let Some(dictionary) = &self.dictionary {
888            zstd::bulk::Decompressor::with_dictionary(dictionary)?
889        } else {
890            zstd::bulk::Decompressor::new()?
891        };
892        let mut decompressed = ByteBufferMut::with_capacity_aligned(
893            uncompressed_size_to_decompress,
894            Alignment::new(byte_width),
895        );
896        unsafe {
897            // safety: we immediately fill all bytes in the following loop,
898            // assuming our metadata's uncompressed size is correct
899            decompressed.set_len(uncompressed_size_to_decompress);
900        }
901        let mut uncompressed_start = 0;
902        for frame in frames_to_decompress {
903            let uncompressed_written = decompressor
904                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])?;
905            uncompressed_start += uncompressed_written;
906        }
907        if uncompressed_start != uncompressed_size_to_decompress {
908            vortex_panic!(
909                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
910                uncompressed_size_to_decompress,
911                uncompressed_start
912            );
913        }
914
915        let decompressed = decompressed.freeze();
916        // Last, we slice the exact values requested out of the decompressed data.
917        let mut slice_validity = unsliced_validity.slice(self.slice_start..self.slice_stop)?;
918
919        // NOTE: this block handles setting the output type when the validity and DType disagree.
920        //
921        // ZSTD is a compact block compressor, meaning that null values are not stored inline in
922        // the data frames. A ZSTD Array that was initialized must always hold onto its full
923        // validity bitmap, even if sliced to only include non-null values.
924        //
925        // We ensure that the validity of the decompressed array ALWAYS matches the validity
926        // implied by the DType.
927        if !dtype.is_nullable() && !matches!(slice_validity, Validity::NonNullable) {
928            assert!(
929                matches!(slice_validity, Validity::AllValid),
930                "ZSTD array expects to be non-nullable but there are nulls after decompression"
931            );
932
933            slice_validity = Validity::NonNullable;
934        } else if dtype.is_nullable() && matches!(slice_validity, Validity::NonNullable) {
935            slice_validity = Validity::AllValid;
936        }
937        //
938        // END OF IMPORTANT BLOCK
939        //
940
941        match dtype {
942            DType::Primitive(..) => {
943                let slice_values_buffer = decompressed.slice(
944                    (slice_value_idx_start - n_skipped_values) * byte_width
945                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
946                );
947                let primitive = PrimitiveArray::from_values_byte_buffer(
948                    slice_values_buffer,
949                    dtype.as_ptype(),
950                    slice_validity,
951                    slice_n_rows,
952                );
953
954                Ok(primitive.into_array())
955            }
956            DType::Binary(_) | DType::Utf8(_) => {
957                match slice_validity.execute_mask(slice_n_rows, ctx)?.indices() {
958                    AllOr::All => {
959                        let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
960                        let valid_views = all_views.slice(
961                            slice_value_idx_start - n_skipped_values
962                                ..slice_value_idx_stop - n_skipped_values,
963                        );
964
965                        // SAFETY: we properly construct the views inside `reconstruct_views`
966                        Ok(unsafe {
967                            VarBinViewArray::new_unchecked(
968                                valid_views,
969                                Arc::from(buffers),
970                                dtype.clone(),
971                                slice_validity,
972                            )
973                        }
974                        .into_array())
975                    }
976                    AllOr::None => Ok(ConstantArray::new(
977                        Scalar::null(dtype.clone()),
978                        slice_n_rows,
979                    )
980                    .into_array()),
981                    AllOr::Some(valid_indices) => {
982                        let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
983                        let valid_views = all_views.slice(
984                            slice_value_idx_start - n_skipped_values
985                                ..slice_value_idx_stop - n_skipped_values,
986                        );
987
988                        let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
989                        for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
990                            views[*index] = view
991                        }
992
993                        // SAFETY: we properly construct the views inside `reconstruct_views`
994                        Ok(unsafe {
995                            VarBinViewArray::new_unchecked(
996                                views.freeze(),
997                                Arc::from(buffers),
998                                dtype.clone(),
999                                slice_validity,
1000                            )
1001                        }
1002                        .into_array())
1003                    }
1004                }
1005            }
1006            _ => vortex_panic!("Unsupported dtype for Zstd array: {}", dtype),
1007        }
1008    }
1009
1010    /// Returns the length of the array.
1011    #[inline]
1012    pub fn len(&self) -> usize {
1013        self.slice_stop - self.slice_start
1014    }
1015
1016    /// Returns whether the array is empty.
1017    #[inline]
1018    pub fn is_empty(&self) -> bool {
1019        self.slice_stop == self.slice_start
1020    }
1021
1022    pub fn into_parts(self, validity: Validity) -> ZstdDataParts {
1023        ZstdDataParts {
1024            dictionary: self.dictionary,
1025            frames: self.frames,
1026            metadata: self.metadata,
1027            validity,
1028            n_rows: self.unsliced_n_rows,
1029            slice_start: self.slice_start,
1030            slice_stop: self.slice_stop,
1031        }
1032    }
1033
1034    pub(crate) fn slice_start(&self) -> usize {
1035        self.slice_start
1036    }
1037
1038    pub(crate) fn slice_stop(&self) -> usize {
1039        self.slice_stop
1040    }
1041
1042    pub(crate) fn unsliced_n_rows(&self) -> usize {
1043        self.unsliced_n_rows
1044    }
1045}
1046
1047impl ValidityVTable<Zstd> for Zstd {
1048    fn validity(array: ArrayView<'_, Zstd>) -> VortexResult<Validity> {
1049        let unsliced_validity =
1050            child_to_validity(array.slots()[0].as_ref(), array.dtype().nullability());
1051        unsliced_validity.slice(array.slice_start()..array.slice_stop())
1052    }
1053}
1054
1055impl OperationsVTable<Zstd> for Zstd {
1056    fn scalar_at(
1057        array: ArrayView<'_, Zstd>,
1058        index: usize,
1059        ctx: &mut ExecutionCtx,
1060    ) -> VortexResult<Scalar> {
1061        let unsliced_validity =
1062            child_to_validity(array.slots()[0].as_ref(), array.dtype().nullability());
1063        let sliced = array.data().with_slice(index, index + 1);
1064        sliced
1065            .decompress(array.dtype(), &unsliced_validity, ctx)?
1066            .execute_scalar(0, ctx)
1067    }
1068}
1069
#[cfg(test)]
#[expect(clippy::cast_possible_truncation)]
mod tests {
    use vortex_buffer::ByteBuffer;

    use super::reconstruct_views;
    use crate::array::BinaryView;

    /// Concatenates `strings` into the interleaved layout consumed by `reconstruct_views`:
    /// for each value, a little-endian `u32` length followed by the value's bytes.
    fn make_interleaved(strings: &[&[u8]]) -> ByteBuffer {
        let bytes: Vec<u8> = strings
            .iter()
            .flat_map(|value| {
                let length_prefix = (value.len() as u32).to_le_bytes();
                length_prefix.into_iter().chain(value.iter().copied())
            })
            .collect();
        ByteBuffer::copy_from(bytes.as_slice())
    }

    #[test]
    fn test_reconstruct_views_no_split() {
        let interleaved = make_interleaved(&[b"hello", b"world"]);
        let (buffers, views) = reconstruct_views(&interleaved, 1024);

        // Everything fits below the limit, so a single buffer backs both views.
        assert_eq!(buffers.len(), 1);
        assert_eq!(views.len(), 2);
        // Each value follows its 4-byte length prefix: data offsets are 4 and 4 + 5 + 4 = 13.
        assert_eq!(views[0], BinaryView::make_view(b"hello", 0, 4));
        assert_eq!(views[1], BinaryView::make_view(b"world", 0, 13));
    }

    #[test]
    fn test_reconstruct_views_split_across_segments() {
        // Two 13-byte values; each entry takes 4 (length prefix) + 13 (data) = 17 bytes.
        // With a limit of 20 bytes, the second value's data (offset 4 + 13 + 4 = 21) no longer
        // fits, so it is placed in a second segment.
        let interleaved = make_interleaved(&[b"aaaaaaaaaaaaa", b"bbbbbbbbbbbbb"]);
        let (buffers, views) = reconstruct_views(&interleaved, 20);

        assert_eq!(buffers.len(), 2);
        assert_eq!(views.len(), 2);
        assert_eq!(views[0], BinaryView::make_view(b"aaaaaaaaaaaaa", 0, 4));
        // The new segment begins at byte 17 (the length prefix), so the local data offset is 4.
        assert_eq!(views[1], BinaryView::make_view(b"bbbbbbbbbbbbb", 1, 4));
    }
}