Skip to main content

vortex_zstd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use itertools::Itertools as _;
12use prost::Message as _;
13use vortex_array::Array;
14use vortex_array::ArrayEq;
15use vortex_array::ArrayHash;
16use vortex_array::ArrayId;
17use vortex_array::ArrayParts;
18use vortex_array::ArrayRef;
19use vortex_array::ArrayView;
20use vortex_array::Canonical;
21use vortex_array::ExecutionCtx;
22use vortex_array::ExecutionResult;
23use vortex_array::IntoArray;
24use vortex_array::Precision;
25use vortex_array::accessor::ArrayAccessor;
26use vortex_array::arrays::ConstantArray;
27use vortex_array::arrays::PrimitiveArray;
28use vortex_array::arrays::VarBinViewArray;
29use vortex_array::arrays::varbinview::build_views::BinaryView;
30use vortex_array::arrays::varbinview::build_views::MAX_BUFFER_LEN;
31use vortex_array::buffer::BufferHandle;
32use vortex_array::dtype::DType;
33use vortex_array::scalar::Scalar;
34use vortex_array::serde::ArrayChildren;
35use vortex_array::smallvec::smallvec;
36use vortex_array::validity::Validity;
37use vortex_array::vtable::OperationsVTable;
38use vortex_array::vtable::VTable;
39use vortex_array::vtable::ValidityVTable;
40use vortex_array::vtable::child_to_validity;
41use vortex_array::vtable::validity_to_child;
42use vortex_buffer::Alignment;
43use vortex_buffer::Buffer;
44use vortex_buffer::BufferMut;
45use vortex_buffer::ByteBuffer;
46use vortex_buffer::ByteBufferMut;
47use vortex_error::VortexError;
48use vortex_error::VortexExpect;
49use vortex_error::VortexResult;
50use vortex_error::vortex_bail;
51use vortex_error::vortex_ensure;
52use vortex_error::vortex_err;
53use vortex_error::vortex_panic;
54use vortex_mask::AllOr;
55use vortex_session::VortexSession;
56use vortex_session::registry::CachedId;
57
58use crate::ZstdFrameMetadata;
59use crate::ZstdMetadata;
60
// Zstd doesn't support training dictionaries on very few samples.
// `compress_values` falls back to dictionary-less compression when there are fewer frames
// (samples) than this.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
// Integer type of the little-endian length prefix stored in front of every binary/string value.
type ViewLen = u32;
64
65// Overall approach here:
66// Zstd can be used on the whole array (values_per_frame = 0), resulting in a single Zstd
67// frame, or it can be used with a dictionary (values_per_frame < # values), resulting in
68// multiple Zstd frames sharing a common dictionary. This latter case is helpful if you
69// want somewhat faster access to slices or individual rows, allowing us to only
70// decompress the necessary frames.
71
72// Visually, during decompression, we have an interval of frames we're
73// decompressing and a tighter interval of the slice we actually care about.
74// |=============values (all valid elements)==============|
75// |<-skipped_uncompressed->|----decompressed-------------|
76//                              |------slice-------|
77//                              ^                  ^
78// |<-slice_uncompressed_start->|                  |
79// |<------------slice_uncompressed_stop---------->|
80// We then insert these values to the correct position using a primitive array
81// constructor.
82
/// A [`Zstd`]-encoded Vortex array.
///
/// This is an alias for the generic [`Array`] parameterised by the [`Zstd`] encoding.
pub type ZstdArray = Array<Zstd>;
85
86impl ArrayHash for ZstdData {
87    fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
88        match &self.dictionary {
89            Some(dict) => {
90                true.hash(state);
91                dict.array_hash(state, precision);
92            }
93            None => {
94                false.hash(state);
95            }
96        }
97        for frame in &self.frames {
98            frame.array_hash(state, precision);
99        }
100        self.unsliced_n_rows.hash(state);
101        self.slice_start.hash(state);
102        self.slice_stop.hash(state);
103    }
104}
105
106impl ArrayEq for ZstdData {
107    fn array_eq(&self, other: &Self, precision: Precision) -> bool {
108        if !match (&self.dictionary, &other.dictionary) {
109            (Some(d1), Some(d2)) => d1.array_eq(d2, precision),
110            (None, None) => true,
111            _ => false,
112        } {
113            return false;
114        }
115        if self.frames.len() != other.frames.len() {
116            return false;
117        }
118        for (a, b) in self.frames.iter().zip(&other.frames) {
119            if !a.array_eq(b, precision) {
120                return false;
121            }
122        }
123        self.unsliced_n_rows == other.unsliced_n_rows
124            && self.slice_start == other.slice_start
125            && self.slice_stop == other.slice_stop
126    }
127}
128
129impl VTable for Zstd {
130    type TypedArrayData = ZstdData;
131
132    type OperationsVTable = Self;
133    type ValidityVTable = Self;
134
135    fn id(&self) -> ArrayId {
136        static ID: CachedId = CachedId::new("vortex.zstd");
137        *ID
138    }
139
140    fn validate(
141        &self,
142        data: &Self::TypedArrayData,
143        dtype: &DType,
144        len: usize,
145        slots: &[Option<ArrayRef>],
146    ) -> VortexResult<()> {
147        let validity = child_to_validity(slots[0].as_ref(), dtype.nullability());
148        data.validate(dtype, len, &validity)
149    }
150
151    fn nbuffers(array: ArrayView<'_, Self>) -> usize {
152        array.dictionary.is_some() as usize + array.frames.len()
153    }
154
155    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
156        if let Some(dict) = &array.dictionary {
157            if idx == 0 {
158                return BufferHandle::new_host(dict.clone());
159            }
160            BufferHandle::new_host(array.frames[idx - 1].clone())
161        } else {
162            BufferHandle::new_host(array.frames[idx].clone())
163        }
164    }
165
166    fn buffer_name(array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
167        if array.dictionary.is_some() {
168            if idx == 0 {
169                Some("dictionary".to_string())
170            } else {
171                Some(format!("frame_{}", idx - 1))
172            }
173        } else {
174            Some(format!("frame_{idx}"))
175        }
176    }
177
178    fn serialize(
179        array: ArrayView<'_, Self>,
180        _session: &VortexSession,
181    ) -> VortexResult<Option<Vec<u8>>> {
182        Ok(Some(array.metadata.clone().encode_to_vec()))
183    }
184
185    fn deserialize(
186        &self,
187        dtype: &DType,
188        len: usize,
189        metadata: &[u8],
190        buffers: &[BufferHandle],
191        children: &dyn ArrayChildren,
192        _session: &VortexSession,
193    ) -> VortexResult<ArrayParts<Self>> {
194        let metadata = ZstdMetadata::decode(metadata)?;
195        let validity = if children.is_empty() {
196            Validity::from(dtype.nullability())
197        } else if children.len() == 1 {
198            let validity = children.get(0, &Validity::DTYPE, len)?;
199            Validity::Array(validity)
200        } else {
201            vortex_bail!("ZstdArray expected 0 or 1 child, got {}", children.len());
202        };
203
204        let (dictionary_buffer, compressed_buffers) = if metadata.dictionary_size == 0 {
205            // no dictionary
206            (
207                None,
208                buffers
209                    .iter()
210                    .map(|b| b.clone().try_to_host_sync())
211                    .collect::<VortexResult<Vec<_>>>()?,
212            )
213        } else {
214            // with dictionary
215            (
216                Some(buffers[0].clone().try_to_host_sync()?),
217                buffers[1..]
218                    .iter()
219                    .map(|b| b.clone().try_to_host_sync())
220                    .collect::<VortexResult<Vec<_>>>()?,
221            )
222        };
223
224        let slots = smallvec![validity_to_child(&validity, len)];
225        let data = ZstdData::new(dictionary_buffer, compressed_buffers, metadata, len);
226        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
227    }
228
229    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
230        SLOT_NAMES[idx].to_string()
231    }
232
233    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
234        let unsliced_validity = child_to_validity(
235            array.as_ref().slots()[0].as_ref(),
236            array.dtype().nullability(),
237        );
238        array
239            .data()
240            .decompress(array.dtype(), &unsliced_validity, ctx)?
241            .execute::<ArrayRef>(ctx)
242            .map(ExecutionResult::done)
243    }
244
245    fn reduce_parent(
246        array: ArrayView<'_, Self>,
247        parent: &ArrayRef,
248        child_idx: usize,
249    ) -> VortexResult<Option<ArrayRef>> {
250        crate::rules::RULES.evaluate(array, parent, child_idx)
251    }
252}
253
/// Marker type for the Zstd encoding.
///
/// It carries no state; it implements [`VTable`] and hosts the constructors that build a
/// [`ZstdArray`].
#[derive(Clone, Debug)]
pub struct Zstd;
256
257impl Zstd {
258    pub fn try_new(dtype: DType, data: ZstdData, validity: Validity) -> VortexResult<ZstdArray> {
259        let len = data.len();
260        data.validate(&dtype, len, &validity)?;
261        let slots = smallvec![validity_to_child(&validity, data.unsliced_n_rows())];
262        Ok(unsafe {
263            Array::from_parts_unchecked(ArrayParts::new(Zstd, dtype, len, data).with_slots(slots))
264        })
265    }
266
267    /// Compress a [`VarBinViewArray`] using Zstd without a dictionary.
268    pub fn from_var_bin_view_without_dict(
269        vbv: &VarBinViewArray,
270        level: i32,
271        values_per_frame: usize,
272        ctx: &mut ExecutionCtx,
273    ) -> VortexResult<ZstdArray> {
274        let validity = vbv.validity()?;
275        Self::try_new(
276            vbv.dtype().clone(),
277            ZstdData::from_var_bin_view_without_dict(vbv, level, values_per_frame, ctx)?,
278            validity,
279        )
280    }
281
282    /// Compress a [`PrimitiveArray`] using Zstd.
283    pub fn from_primitive(
284        parray: &PrimitiveArray,
285        level: i32,
286        values_per_frame: usize,
287        ctx: &mut ExecutionCtx,
288    ) -> VortexResult<ZstdArray> {
289        let validity = parray.validity()?;
290        Self::try_new(
291            parray.dtype().clone(),
292            ZstdData::from_primitive(parray, level, values_per_frame, ctx)?,
293            validity,
294        )
295    }
296
297    /// Compress a [`VarBinViewArray`] using Zstd.
298    pub fn from_var_bin_view(
299        vbv: &VarBinViewArray,
300        level: i32,
301        values_per_frame: usize,
302        ctx: &mut ExecutionCtx,
303    ) -> VortexResult<ZstdArray> {
304        let validity = vbv.validity()?;
305        Self::try_new(
306            vbv.dtype().clone(),
307            ZstdData::from_var_bin_view(vbv, level, values_per_frame, ctx)?,
308            validity,
309        )
310    }
311
312    pub fn decompress(array: &ZstdArray, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
313        let unsliced_validity = child_to_validity(
314            array.as_ref().slots()[0].as_ref(),
315            array.dtype().nullability(),
316        );
317        array
318            .data()
319            .decompress(array.dtype(), &unsliced_validity, ctx)
320    }
321}
322
/// Number of child slots in a Zstd array.
pub(super) const NUM_SLOTS: usize = 1;
/// Slot names, indexed by slot position. The only slot is the validity child, i.e. the bitmap
/// indicating which elements are non-null.
pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
326
/// The encoding-specific payload of a Zstd array: compressed buffers, their metadata, and the
/// row window currently exposed.
#[derive(Clone, Debug)]
pub struct ZstdData {
    // Optional Zstd dictionary; when present, its length must equal `metadata.dictionary_size`.
    pub(crate) dictionary: Option<ByteBuffer>,
    // Compressed frames, one buffer per entry in `metadata.frames`.
    pub(crate) frames: Vec<ByteBuffer>,
    // Dictionary size plus per-frame metadata (uncompressed size, value count).
    pub(crate) metadata: ZstdMetadata,
    // Row count of the array before any slicing; the validity is sized to this.
    unsliced_n_rows: usize,
    // First row of the exposed window within the unsliced rows.
    slice_start: usize,
    // One past the last row of the exposed window; `slice_stop - slice_start` is the array length.
    slice_stop: usize,
}
336
337impl Display for ZstdData {
338    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
339        write!(
340            f,
341            "nrows: {}, slice: {}..{}",
342            self.unsliced_n_rows, self.slice_start, self.slice_stop
343        )
344    }
345}
346
/// The owned pieces that make up a Zstd array: compressed buffers, metadata, validity, the
/// row count and the slice bounds.
///
/// NOTE(review): presumably the return type of an `into_parts` method — confirm.
pub struct ZstdDataParts {
    pub dictionary: Option<ByteBuffer>,
    pub frames: Vec<ByteBuffer>,
    pub metadata: ZstdMetadata,
    pub validity: Validity,
    pub n_rows: usize,
    pub slice_start: usize,
    pub slice_stop: usize,
}
356
/// Output of `ZstdData::compress_values`: the optional trained dictionary, the compressed
/// frames, and one [`ZstdFrameMetadata`] per frame.
#[derive(Debug)]
struct Frames {
    dictionary: Option<ByteBuffer>,
    frames: Vec<ByteBuffer>,
    frame_metas: Vec<ZstdFrameMetadata>,
}
364
/// Picks the maximum dictionary size, in bytes, for a payload of `uncompressed_size` bytes.
///
/// Follows the recommendation in
/// https://github.com/facebook/zstd/blob/v1.5.5/lib/zdict.h#L190
/// (1/100 of the data size, up to 100kB), with a floor of 256 bytes because zstd appears unable
/// to train dictionaries smaller than that.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const MIN_DICT_BYTES: usize = 256;
    const MAX_DICT_BYTES: usize = 100 * 1024;

    let one_percent = uncompressed_size / 100;
    one_percent.max(MIN_DICT_BYTES).min(MAX_DICT_BYTES)
}
372
373fn collect_valid_primitive(
374    parray: &PrimitiveArray,
375    ctx: &mut ExecutionCtx,
376) -> VortexResult<PrimitiveArray> {
377    let mask = parray
378        .as_ref()
379        .validity()?
380        .execute_mask(parray.as_ref().len(), ctx)?;
381    let result = parray.filter(mask)?.execute::<PrimitiveArray>(ctx)?;
382    Ok(result)
383}
384
385fn collect_valid_vbv(
386    vbv: &VarBinViewArray,
387    ctx: &mut ExecutionCtx,
388) -> VortexResult<(ByteBuffer, Vec<usize>)> {
389    let mask = vbv
390        .as_ref()
391        .validity()?
392        .execute_mask(vbv.as_ref().len(), ctx)?;
393    let buffer_and_value_byte_indices = match mask.bit_buffer() {
394        AllOr::None => (Buffer::empty(), Vec::new()),
395        _ => {
396            let mut buffer = BufferMut::with_capacity(
397                usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
398                    + mask.true_count() * size_of::<ViewLen>(),
399            );
400            let mut value_byte_indices = Vec::new();
401            vbv.with_iterator(|iterator| {
402                // by flattening, we should omit nulls
403                for value in iterator.flatten() {
404                    value_byte_indices.push(buffer.len());
405                    // here's where we write the string lengths
406                    buffer
407                        .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
408                    buffer.extend_from_slice(value);
409                }
410                Ok::<_, VortexError>(())
411            })?;
412            (buffer.freeze(), value_byte_indices)
413        }
414    };
415    Ok(buffer_and_value_byte_indices)
416}
417
418/// Reconstruct BinaryView structs from length-prefixed byte data.
419///
420/// The buffer contains interleaved u32 lengths (little-endian) and string data.
421/// When the cumulative data exceeds `max_buffer_len`, the buffer is split (zero-copy) into
422/// multiple segments so that BinaryView's u32 offsets can address all data.
423///
424/// Pass [`MAX_BUFFER_LEN`] for `max_buffer_len` in production; a smaller value can be used in
425/// tests to exercise the splitting path without allocating >2 GiB.
426pub fn reconstruct_views(
427    buffer: &ByteBuffer,
428    max_buffer_len: usize,
429) -> (Vec<ByteBuffer>, Buffer<BinaryView>) {
430    let mut views = BufferMut::<BinaryView>::empty();
431    let mut buffers = Vec::new();
432    let mut segment_start: usize = 0;
433    let mut offset = 0;
434
435    while offset < buffer.len() {
436        let str_len = ViewLen::from_le_bytes(
437            buffer
438                .get(offset..offset + size_of::<ViewLen>())
439                .vortex_expect("corrupted zstd length")
440                .try_into()
441                .ok()
442                .vortex_expect("must fit ViewLen size"),
443        ) as usize;
444
445        let value_data_offset = offset + size_of::<ViewLen>();
446        let local_offset = value_data_offset - segment_start;
447
448        if local_offset + str_len > max_buffer_len && offset > segment_start {
449            buffers.push(buffer.slice(segment_start..offset));
450            segment_start = offset;
451        }
452
453        let local_offset = u32::try_from(value_data_offset - segment_start)
454            .vortex_expect("local offset within segment must fit in u32");
455        let buf_index = u32::try_from(buffers.len()).vortex_expect("buffer index must fit in u32");
456        let value = &buffer[value_data_offset..value_data_offset + str_len];
457        views.push(BinaryView::make_view(value, buf_index, local_offset));
458        offset = value_data_offset + str_len;
459    }
460
461    if segment_start < buffer.len() {
462        buffers.push(buffer.slice(segment_start..buffer.len()));
463    }
464
465    (buffers, views.freeze())
466}
467
468impl ZstdData {
469    pub fn new(
470        dictionary: Option<ByteBuffer>,
471        frames: Vec<ByteBuffer>,
472        metadata: ZstdMetadata,
473        n_rows: usize,
474    ) -> Self {
475        Self {
476            dictionary,
477            frames,
478            metadata,
479            unsliced_n_rows: n_rows,
480            slice_start: 0,
481            slice_stop: n_rows,
482        }
483    }
484
485    pub fn validate(&self, dtype: &DType, len: usize, validity: &Validity) -> VortexResult<()> {
486        vortex_ensure!(
487            matches!(
488                dtype,
489                DType::Primitive(..) | DType::Binary(_) | DType::Utf8(_)
490            ),
491            "Unsupported dtype for Zstd array: {dtype}"
492        );
493        vortex_ensure!(
494            self.slice_start <= self.slice_stop,
495            "Invalid slice range {}..{}",
496            self.slice_start,
497            self.slice_stop
498        );
499        vortex_ensure!(
500            self.slice_stop <= self.unsliced_n_rows,
501            "Slice stop {} exceeds unsliced row count {}",
502            self.slice_stop,
503            self.unsliced_n_rows
504        );
505        vortex_ensure!(
506            self.slice_stop - self.slice_start == len,
507            "Slice length {} does not match array length {}",
508            self.slice_stop - self.slice_start,
509            len
510        );
511        if let Some(validity_len) = validity.maybe_len() {
512            vortex_ensure!(
513                validity_len == self.unsliced_n_rows,
514                "Validity length {} does not match unsliced row count {}",
515                validity_len,
516                self.unsliced_n_rows
517            );
518        }
519
520        match &self.dictionary {
521            Some(dictionary) => vortex_ensure!(
522                usize::try_from(self.metadata.dictionary_size)? == dictionary.len(),
523                "Dictionary size metadata {} does not match buffer size {}",
524                self.metadata.dictionary_size,
525                dictionary.len()
526            ),
527            None => vortex_ensure!(
528                self.metadata.dictionary_size == 0,
529                "Dictionary metadata present without dictionary buffer"
530            ),
531        }
532        vortex_ensure!(
533            self.frames.len() == self.metadata.frames.len(),
534            "Frame count {} does not match metadata frame count {}",
535            self.frames.len(),
536            self.metadata.frames.len()
537        );
538
539        Ok(())
540    }
541
542    pub(crate) fn with_slice(&self, start: usize, stop: usize) -> Self {
543        let new_start = self.slice_start + start;
544        let new_stop = self.slice_start + stop;
545
546        assert!(
547            new_start <= self.slice_stop,
548            "new slice start {new_start} exceeds end {}",
549            self.slice_stop
550        );
551
552        assert!(
553            new_stop <= self.slice_stop,
554            "new slice stop {new_stop} exceeds end {}",
555            self.slice_stop
556        );
557
558        Self {
559            slice_start: new_start,
560            slice_stop: new_stop,
561            ..self.clone()
562        }
563    }
564
565    fn compress_values(
566        value_bytes: &ByteBuffer,
567        frame_byte_starts: &[usize],
568        level: i32,
569        values_per_frame: usize,
570        n_values: usize,
571        use_dictionary: bool,
572    ) -> VortexResult<Frames> {
573        let n_frames = frame_byte_starts.len();
574
575        // Would-be sample sizes if we end up applying zstd dictionary
576        let mut sample_sizes = Vec::with_capacity(n_frames);
577        for i in 0..n_frames {
578            let frame_byte_end = frame_byte_starts
579                .get(i + 1)
580                .copied()
581                .unwrap_or(value_bytes.len());
582            sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
583        }
584        debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
585
586        let (dictionary, mut compressor) = if !use_dictionary
587            || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY
588        {
589            // no dictionary
590            (None, zstd::bulk::Compressor::new(level)?)
591        } else {
592            // with dictionary
593            let max_dict_size = choose_max_dict_size(value_bytes.len());
594            let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
595                .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
596
597            let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
598            (Some(ByteBuffer::from(dict)), compressor)
599        };
600
601        let mut frame_metas = vec![];
602        let mut frames = vec![];
603        for i in 0..n_frames {
604            let frame_byte_end = frame_byte_starts
605                .get(i + 1)
606                .copied()
607                .unwrap_or(value_bytes.len());
608
609            let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
610            let compressed = compressor
611                .compress(uncompressed)
612                .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
613            frame_metas.push(ZstdFrameMetadata {
614                uncompressed_size: uncompressed.len() as u64,
615                n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
616            });
617            frames.push(ByteBuffer::from(compressed));
618        }
619
620        Ok(Frames {
621            dictionary,
622            frames,
623            frame_metas,
624        })
625    }
626
627    /// Creates a ZstdArray from a primitive array.
628    ///
629    /// # Arguments
630    /// * `parray` - The primitive array to compress
631    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
632    /// * `values_per_frame` - Number of values per frame (0 = single frame)
633    pub fn from_primitive(
634        parray: &PrimitiveArray,
635        level: i32,
636        values_per_frame: usize,
637        ctx: &mut ExecutionCtx,
638    ) -> VortexResult<Self> {
639        Self::from_primitive_impl(parray, level, values_per_frame, true, ctx)
640    }
641
642    /// Creates a ZstdArray from a primitive array without using a dictionary.
643    ///
644    /// This is useful when the compressed data will be decompressed by systems
645    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
646    ///
647    /// Note: Without a dictionary, each frame is compressed independently.
648    /// Dictionaries are trained from sample data from previously seen frames,
649    /// to improve compression ratio.
650    ///
651    /// # Arguments
652    /// * `parray` - The primitive array to compress
653    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
654    /// * `values_per_frame` - Number of values per frame (0 = single frame)
655    pub fn from_primitive_without_dict(
656        parray: &PrimitiveArray,
657        level: i32,
658        values_per_frame: usize,
659        ctx: &mut ExecutionCtx,
660    ) -> VortexResult<Self> {
661        Self::from_primitive_impl(parray, level, values_per_frame, false, ctx)
662    }
663
664    fn from_primitive_impl(
665        parray: &PrimitiveArray,
666        level: i32,
667        values_per_frame: usize,
668        use_dictionary: bool,
669        ctx: &mut ExecutionCtx,
670    ) -> VortexResult<Self> {
671        let byte_width = parray.ptype().byte_width();
672
673        // We compress only the valid elements.
674        let values = collect_valid_primitive(parray, ctx)?;
675        let n_values = values.len();
676        let values_per_frame = if values_per_frame > 0 {
677            values_per_frame
678        } else {
679            n_values
680        };
681
682        let value_bytes = values.buffer_handle().try_to_host_sync()?;
683        // Align frames to buffer alignment. This is necessary for overaligned buffers.
684        let alignment = *value_bytes.alignment();
685        let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment;
686
687        let frame_byte_starts = (0..n_values * byte_width)
688            .step_by(step_width)
689            .collect::<Vec<_>>();
690        let Frames {
691            dictionary,
692            frames,
693            frame_metas,
694        } = Self::compress_values(
695            &value_bytes,
696            &frame_byte_starts,
697            level,
698            values_per_frame,
699            n_values,
700            use_dictionary,
701        )?;
702
703        let metadata = ZstdMetadata {
704            dictionary_size: dictionary
705                .as_ref()
706                .map_or(0, |dict| dict.len())
707                .try_into()?,
708            frames: frame_metas,
709        };
710
711        Ok(ZstdData::new(dictionary, frames, metadata, parray.len()))
712    }
713
714    /// Creates a ZstdArray from a VarBinView array.
715    ///
716    /// # Arguments
717    /// * `vbv` - The VarBinView array to compress
718    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
719    /// * `values_per_frame` - Number of values per frame (0 = single frame)
720    pub fn from_var_bin_view(
721        vbv: &VarBinViewArray,
722        level: i32,
723        values_per_frame: usize,
724        ctx: &mut ExecutionCtx,
725    ) -> VortexResult<Self> {
726        Self::from_var_bin_view_impl(vbv, level, values_per_frame, true, ctx)
727    }
728
729    /// Creates a ZstdArray from a VarBinView array without using a dictionary.
730    ///
731    /// This is useful when the compressed data will be decompressed by systems
732    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
733    ///
734    /// Note: Without a dictionary, each frame is compressed independently.
735    /// Dictionaries are trained from sample data from previously seen frames,
736    /// to improve compression ratio.
737    ///
738    /// # Arguments
739    /// * `vbv` - The VarBinView array to compress
740    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
741    /// * `values_per_frame` - Number of values per frame (0 = single frame)
742    pub fn from_var_bin_view_without_dict(
743        vbv: &VarBinViewArray,
744        level: i32,
745        values_per_frame: usize,
746        ctx: &mut ExecutionCtx,
747    ) -> VortexResult<Self> {
748        Self::from_var_bin_view_impl(vbv, level, values_per_frame, false, ctx)
749    }
750
751    fn from_var_bin_view_impl(
752        vbv: &VarBinViewArray,
753        level: i32,
754        values_per_frame: usize,
755        use_dictionary: bool,
756        ctx: &mut ExecutionCtx,
757    ) -> VortexResult<Self> {
758        // Approach for strings: we prefix each string with its length as a u32.
759        // This is the same as what Parquet does. In some cases it may be better
760        // to separate the binary data and lengths as two separate streams, but
761        // this approach is simpler and can be best in cases when there is
762        // mutual information between strings and their lengths.
763        // We compress only the valid elements.
764        let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv, ctx)?;
765        let n_values = value_byte_indices.len();
766        let values_per_frame = if values_per_frame > 0 {
767            values_per_frame
768        } else {
769            n_values
770        };
771
772        let frame_byte_starts = (0..n_values)
773            .step_by(values_per_frame)
774            .map(|i| value_byte_indices[i])
775            .collect::<Vec<_>>();
776        let Frames {
777            dictionary,
778            frames,
779            frame_metas,
780        } = Self::compress_values(
781            &value_bytes,
782            &frame_byte_starts,
783            level,
784            values_per_frame,
785            n_values,
786            use_dictionary,
787        )?;
788
789        let metadata = ZstdMetadata {
790            dictionary_size: dictionary
791                .as_ref()
792                .map_or(0, |dict| dict.len())
793                .try_into()?,
794            frames: frame_metas,
795        };
796        Ok(ZstdData::new(dictionary, frames, metadata, vbv.len()))
797    }
798
799    pub fn from_canonical(
800        canonical: &Canonical,
801        level: i32,
802        values_per_frame: usize,
803        ctx: &mut ExecutionCtx,
804    ) -> VortexResult<Option<Self>> {
805        match canonical {
806            Canonical::Primitive(parray) => Ok(Some(ZstdData::from_primitive(
807                parray,
808                level,
809                values_per_frame,
810                ctx,
811            )?)),
812            Canonical::VarBinView(vbv) => Ok(Some(ZstdData::from_var_bin_view(
813                vbv,
814                level,
815                values_per_frame,
816                ctx,
817            )?)),
818            _ => Ok(None),
819        }
820    }
821
822    pub fn from_array(
823        array: ArrayRef,
824        level: i32,
825        values_per_frame: usize,
826        ctx: &mut ExecutionCtx,
827    ) -> VortexResult<Self> {
828        let canonical = array.execute::<Canonical>(ctx)?;
829        Self::from_canonical(&canonical, level, values_per_frame, ctx)?
830            .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
831    }
832
833    fn byte_width(dtype: &DType) -> usize {
834        if dtype.is_primitive() {
835            dtype.as_ptype().byte_width()
836        } else {
837            1
838        }
839    }
840
841    fn decompress(
842        &self,
843        dtype: &DType,
844        unsliced_validity: &Validity,
845        ctx: &mut ExecutionCtx,
846    ) -> VortexResult<ArrayRef> {
847        // To start, we figure out which frames we need to decompress, and with
848        // what row offset into the first such frame.
849        let byte_width = Self::byte_width(dtype);
850        let slice_n_rows = self.slice_stop - self.slice_start;
851        let slice_value_indices = unsliced_validity
852            .execute_mask(self.unsliced_n_rows, ctx)?
853            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
854
855        let slice_value_idx_start = slice_value_indices[0];
856        let slice_value_idx_stop = slice_value_indices[1];
857
858        let mut frames_to_decompress = vec![];
859        let mut value_idx_start = 0;
860        let mut uncompressed_size_to_decompress = 0;
861        let mut n_skipped_values = 0;
862        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
863            if value_idx_start >= slice_value_idx_stop {
864                break;
865            }
866
867            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
868                .vortex_expect("Uncompressed size must fit in usize");
869            let frame_n_values = if frame_meta.n_values == 0 {
870                // possibly older primitive-only metadata that just didn't store this
871                frame_uncompressed_size / byte_width
872            } else {
873                usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
874            };
875
876            let value_idx_stop = value_idx_start + frame_n_values;
877            if value_idx_stop > slice_value_idx_start {
878                // we need this frame
879                frames_to_decompress.push(frame);
880                uncompressed_size_to_decompress += frame_uncompressed_size;
881            } else {
882                n_skipped_values += frame_n_values;
883            }
884            value_idx_start = value_idx_stop;
885        }
886
887        // then we actually decompress those frames
888        let mut decompressor = if let Some(dictionary) = &self.dictionary {
889            zstd::bulk::Decompressor::with_dictionary(dictionary)?
890        } else {
891            zstd::bulk::Decompressor::new()?
892        };
893        let mut decompressed = ByteBufferMut::with_capacity_aligned(
894            uncompressed_size_to_decompress,
895            Alignment::new(byte_width),
896        );
897        unsafe {
898            // safety: we immediately fill all bytes in the following loop,
899            // assuming our metadata's uncompressed size is correct
900            decompressed.set_len(uncompressed_size_to_decompress);
901        }
902        let mut uncompressed_start = 0;
903        for frame in frames_to_decompress {
904            let uncompressed_written = decompressor
905                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])?;
906            uncompressed_start += uncompressed_written;
907        }
908        if uncompressed_start != uncompressed_size_to_decompress {
909            vortex_panic!(
910                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
911                uncompressed_size_to_decompress,
912                uncompressed_start
913            );
914        }
915
916        let decompressed = decompressed.freeze();
917        // Last, we slice the exact values requested out of the decompressed data.
918        let mut slice_validity = unsliced_validity.slice(self.slice_start..self.slice_stop)?;
919
920        // NOTE: this block handles setting the output type when the validity and DType disagree.
921        //
922        // ZSTD is a compact block compressor, meaning that null values are not stored inline in
923        // the data frames. A ZSTD Array that was initialized must always hold onto its full
924        // validity bitmap, even if sliced to only include non-null values.
925        //
926        // We ensure that the validity of the decompressed array ALWAYS matches the validity
927        // implied by the DType.
928        if !dtype.is_nullable() && !matches!(slice_validity, Validity::NonNullable) {
929            assert!(
930                matches!(slice_validity, Validity::AllValid),
931                "ZSTD array expects to be non-nullable but there are nulls after decompression"
932            );
933
934            slice_validity = Validity::NonNullable;
935        } else if dtype.is_nullable() && matches!(slice_validity, Validity::NonNullable) {
936            slice_validity = Validity::AllValid;
937        }
938        //
939        // END OF IMPORTANT BLOCK
940        //
941
942        match dtype {
943            DType::Primitive(..) => {
944                let slice_values_buffer = decompressed.slice(
945                    (slice_value_idx_start - n_skipped_values) * byte_width
946                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
947                );
948                let primitive = PrimitiveArray::from_values_byte_buffer(
949                    slice_values_buffer,
950                    dtype.as_ptype(),
951                    slice_validity,
952                    slice_n_rows,
953                );
954
955                Ok(primitive.into_array())
956            }
957            DType::Binary(_) | DType::Utf8(_) => {
958                match slice_validity.execute_mask(slice_n_rows, ctx)?.indices() {
959                    AllOr::All => {
960                        let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
961                        let valid_views = all_views.slice(
962                            slice_value_idx_start - n_skipped_values
963                                ..slice_value_idx_stop - n_skipped_values,
964                        );
965
966                        // SAFETY: we properly construct the views inside `reconstruct_views`
967                        Ok(unsafe {
968                            VarBinViewArray::new_unchecked(
969                                valid_views,
970                                Arc::from(buffers),
971                                dtype.clone(),
972                                slice_validity,
973                            )
974                        }
975                        .into_array())
976                    }
977                    AllOr::None => Ok(ConstantArray::new(
978                        Scalar::null(dtype.clone()),
979                        slice_n_rows,
980                    )
981                    .into_array()),
982                    AllOr::Some(valid_indices) => {
983                        let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
984                        let valid_views = all_views.slice(
985                            slice_value_idx_start - n_skipped_values
986                                ..slice_value_idx_stop - n_skipped_values,
987                        );
988
989                        let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
990                        for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
991                            views[*index] = view
992                        }
993
994                        // SAFETY: we properly construct the views inside `reconstruct_views`
995                        Ok(unsafe {
996                            VarBinViewArray::new_unchecked(
997                                views.freeze(),
998                                Arc::from(buffers),
999                                dtype.clone(),
1000                                slice_validity,
1001                            )
1002                        }
1003                        .into_array())
1004                    }
1005                }
1006            }
1007            _ => vortex_panic!("Unsupported dtype for Zstd array: {}", dtype),
1008        }
1009    }
1010
1011    /// Returns the length of the array.
1012    #[inline]
1013    pub fn len(&self) -> usize {
1014        self.slice_stop - self.slice_start
1015    }
1016
1017    /// Returns whether the array is empty.
1018    #[inline]
1019    pub fn is_empty(&self) -> bool {
1020        self.slice_stop == self.slice_start
1021    }
1022
1023    pub fn into_parts(self, validity: Validity) -> ZstdDataParts {
1024        ZstdDataParts {
1025            dictionary: self.dictionary,
1026            frames: self.frames,
1027            metadata: self.metadata,
1028            validity,
1029            n_rows: self.unsliced_n_rows,
1030            slice_start: self.slice_start,
1031            slice_stop: self.slice_stop,
1032        }
1033    }
1034
1035    pub(crate) fn slice_start(&self) -> usize {
1036        self.slice_start
1037    }
1038
1039    pub(crate) fn slice_stop(&self) -> usize {
1040        self.slice_stop
1041    }
1042
1043    pub(crate) fn unsliced_n_rows(&self) -> usize {
1044        self.unsliced_n_rows
1045    }
1046}
1047
1048impl ValidityVTable<Zstd> for Zstd {
1049    fn validity(array: ArrayView<'_, Zstd>) -> VortexResult<Validity> {
1050        let unsliced_validity =
1051            child_to_validity(array.slots()[0].as_ref(), array.dtype().nullability());
1052        unsliced_validity.slice(array.slice_start()..array.slice_stop())
1053    }
1054}
1055
1056impl OperationsVTable<Zstd> for Zstd {
1057    fn scalar_at(
1058        array: ArrayView<'_, Zstd>,
1059        index: usize,
1060        ctx: &mut ExecutionCtx,
1061    ) -> VortexResult<Scalar> {
1062        let unsliced_validity =
1063            child_to_validity(array.slots()[0].as_ref(), array.dtype().nullability());
1064        let sliced = array.data().with_slice(index, index + 1);
1065        sliced
1066            .decompress(array.dtype(), &unsliced_validity, ctx)?
1067            .execute_scalar(0, ctx)
1068    }
1069}
1070
#[cfg(test)]
#[expect(clippy::cast_possible_truncation)]
mod tests {
    use vortex_buffer::ByteBuffer;

    use super::reconstruct_views;
    use crate::array::BinaryView;

    /// Encodes `strings` the way the Zstd array lays them out: each entry is a
    /// little-endian `u32` length immediately followed by the entry's bytes.
    fn make_interleaved(strings: &[&[u8]]) -> ByteBuffer {
        let bytes: Vec<u8> = strings
            .iter()
            .flat_map(|entry| {
                let prefix = (entry.len() as u32).to_le_bytes();
                prefix.into_iter().chain(entry.iter().copied())
            })
            .collect();
        ByteBuffer::copy_from(bytes.as_slice())
    }

    #[test]
    fn test_reconstruct_views_no_split() {
        let interleaved = make_interleaved(&[b"hello", b"world"]);
        let (buffers, views) = reconstruct_views(&interleaved, 1024);

        // Everything fits under the 1024-byte limit, so a single buffer is produced.
        assert_eq!(buffers.len(), 1);
        assert_eq!(views.len(), 2);

        // Data follows its 4-byte length prefix: "hello" starts at 4, and
        // "world" starts at 4 + 5 + 4 = 13.
        let expected_hello = BinaryView::make_view(b"hello", 0, 4);
        let expected_world = BinaryView::make_view(b"world", 0, 13);
        assert_eq!(views[0], expected_hello);
        assert_eq!(views[1], expected_world);
    }

    #[test]
    fn test_reconstruct_views_split_across_segments() {
        // Two 13-byte entries, each taking 4 (length prefix) + 13 (data) = 17 bytes.
        // With a 20-byte limit, the second entry's data would start at offset
        // 4 + 13 + 4 = 21, which is past the limit, so it moves to a second segment.
        let first: &[u8] = b"aaaaaaaaaaaaa";
        let second: &[u8] = b"bbbbbbbbbbbbb";
        let interleaved = make_interleaved(&[first, second]);
        let (buffers, views) = reconstruct_views(&interleaved, 20);

        assert_eq!(buffers.len(), 2);
        assert_eq!(views.len(), 2);
        assert_eq!(views[0], BinaryView::make_view(first, 0, 4));
        // The new segment begins at byte 17 (the second length prefix), so the
        // data sits at local offset 4 inside buffer 1.
        assert_eq!(views[1], BinaryView::make_view(second, 1, 4));
    }
}