// vortex_zstd/array.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use itertools::Itertools as _;
12use prost::Message as _;
13use vortex_array::Array;
14use vortex_array::ArrayEq;
15use vortex_array::ArrayHash;
16use vortex_array::ArrayId;
17use vortex_array::ArrayParts;
18use vortex_array::ArrayRef;
19use vortex_array::ArrayView;
20use vortex_array::Canonical;
21use vortex_array::ExecutionCtx;
22use vortex_array::ExecutionResult;
23use vortex_array::IntoArray;
24use vortex_array::LEGACY_SESSION;
25use vortex_array::Precision;
26use vortex_array::ToCanonical;
27use vortex_array::VortexSessionExecute;
28use vortex_array::accessor::ArrayAccessor;
29use vortex_array::arrays::ConstantArray;
30use vortex_array::arrays::PrimitiveArray;
31use vortex_array::arrays::VarBinViewArray;
32use vortex_array::arrays::varbinview::build_views::BinaryView;
33use vortex_array::arrays::varbinview::build_views::MAX_BUFFER_LEN;
34use vortex_array::buffer::BufferHandle;
35use vortex_array::dtype::DType;
36use vortex_array::scalar::Scalar;
37use vortex_array::serde::ArrayChildren;
38use vortex_array::validity::Validity;
39use vortex_array::vtable::OperationsVTable;
40use vortex_array::vtable::VTable;
41use vortex_array::vtable::ValidityVTable;
42use vortex_array::vtable::child_to_validity;
43use vortex_array::vtable::validity_to_child;
44use vortex_buffer::Alignment;
45use vortex_buffer::Buffer;
46use vortex_buffer::BufferMut;
47use vortex_buffer::ByteBuffer;
48use vortex_buffer::ByteBufferMut;
49use vortex_error::VortexError;
50use vortex_error::VortexExpect;
51use vortex_error::VortexResult;
52use vortex_error::vortex_bail;
53use vortex_error::vortex_ensure;
54use vortex_error::vortex_err;
55use vortex_error::vortex_panic;
56use vortex_mask::AllOr;
57use vortex_session::VortexSession;
58use vortex_session::registry::CachedId;
59
60use crate::ZstdFrameMetadata;
61use crate::ZstdMetadata;
62
// Zstd doesn't support training dictionaries on very few samples.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
// Integer type of the little-endian length prefix written before each
// variable-length value in the uncompressed stream.
type ViewLen = u32;
66
67// Overall approach here:
68// Zstd can be used on the whole array (values_per_frame = 0), resulting in a single Zstd
69// frame, or it can be used with a dictionary (values_per_frame < # values), resulting in
70// multiple Zstd frames sharing a common dictionary. This latter case is helpful if you
71// want somewhat faster access to slices or individual rows, allowing us to only
72// decompress the necessary frames.
73
74// Visually, during decompression, we have an interval of frames we're
75// decompressing and a tighter interval of the slice we actually care about.
76// |=============values (all valid elements)==============|
77// |<-skipped_uncompressed->|----decompressed-------------|
78//                              |------slice-------|
79//                              ^                  ^
80// |<-slice_uncompressed_start->|                  |
81// |<------------slice_uncompressed_stop---------->|
82// We then insert these values to the correct position using a primitive array
83// constructor.
84
/// A [`Zstd`]-encoded Vortex array.
///
/// See the frame/dictionary layout discussion in the comments above.
pub type ZstdArray = Array<Zstd>;
87
88impl ArrayHash for ZstdData {
89    fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
90        match &self.dictionary {
91            Some(dict) => {
92                true.hash(state);
93                dict.array_hash(state, precision);
94            }
95            None => {
96                false.hash(state);
97            }
98        }
99        for frame in &self.frames {
100            frame.array_hash(state, precision);
101        }
102        self.unsliced_n_rows.hash(state);
103        self.slice_start.hash(state);
104        self.slice_stop.hash(state);
105    }
106}
107
108impl ArrayEq for ZstdData {
109    fn array_eq(&self, other: &Self, precision: Precision) -> bool {
110        if !match (&self.dictionary, &other.dictionary) {
111            (Some(d1), Some(d2)) => d1.array_eq(d2, precision),
112            (None, None) => true,
113            _ => false,
114        } {
115            return false;
116        }
117        if self.frames.len() != other.frames.len() {
118            return false;
119        }
120        for (a, b) in self.frames.iter().zip(&other.frames) {
121            if !a.array_eq(b, precision) {
122                return false;
123            }
124        }
125        self.unsliced_n_rows == other.unsliced_n_rows
126            && self.slice_start == other.slice_start
127            && self.slice_stop == other.slice_stop
128    }
129}
130
impl VTable for Zstd {
    type ArrayData = ZstdData;

    type OperationsVTable = Self;
    type ValidityVTable = Self;

    /// Stable identifier under which this encoding is registered.
    fn id(&self) -> ArrayId {
        static ID: CachedId = CachedId::new("vortex.zstd");
        *ID
    }

    /// Validate `data` against the array-level dtype/len, converting slot 0
    /// back into a [`Validity`] first.
    fn validate(
        &self,
        data: &Self::ArrayData,
        dtype: &DType,
        len: usize,
        slots: &[Option<ArrayRef>],
    ) -> VortexResult<()> {
        let validity = child_to_validity(&slots[0], dtype.nullability());
        data.validate(dtype, len, &validity)
    }

    /// Buffer layout: the optional dictionary first, then one buffer per frame.
    fn nbuffers(array: ArrayView<'_, Self>) -> usize {
        array.dictionary.is_some() as usize + array.frames.len()
    }

    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
        if let Some(dict) = &array.dictionary {
            // With a dictionary present, buffer 0 is the dictionary and the
            // frame indices shift up by one.
            if idx == 0 {
                return BufferHandle::new_host(dict.clone());
            }
            BufferHandle::new_host(array.frames[idx - 1].clone())
        } else {
            BufferHandle::new_host(array.frames[idx].clone())
        }
    }

    /// Names mirror the indexing scheme of [`Self::buffer`].
    fn buffer_name(array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
        if array.dictionary.is_some() {
            if idx == 0 {
                Some("dictionary".to_string())
            } else {
                Some(format!("frame_{}", idx - 1))
            }
        } else {
            Some(format!("frame_{idx}"))
        }
    }

    /// Metadata round-trips as protobuf-encoded bytes (see [`ZstdMetadata`]).
    fn serialize(
        array: ArrayView<'_, Self>,
        _session: &VortexSession,
    ) -> VortexResult<Option<Vec<u8>>> {
        Ok(Some(array.metadata.clone().encode_to_vec()))
    }

    /// Rebuild a [`ZstdData`] from serialized metadata, buffers, and the
    /// optional validity child.
    fn deserialize(
        &self,
        dtype: &DType,
        len: usize,
        metadata: &[u8],
        buffers: &[BufferHandle],
        children: &dyn ArrayChildren,
        _session: &VortexSession,
    ) -> VortexResult<ArrayParts<Self>> {
        let metadata = ZstdMetadata::decode(metadata)?;
        // Slot 0, when present, is the validity child; no child means validity
        // is implied by the dtype's nullability.
        let validity = if children.is_empty() {
            Validity::from(dtype.nullability())
        } else if children.len() == 1 {
            let validity = children.get(0, &Validity::DTYPE, len)?;
            Validity::Array(validity)
        } else {
            vortex_bail!("ZstdArray expected 0 or 1 child, got {}", children.len());
        };

        // `dictionary_size == 0` encodes "no dictionary"; otherwise buffer 0 is
        // the dictionary and the rest are frames (matching `buffer`'s layout).
        let (dictionary_buffer, compressed_buffers) = if metadata.dictionary_size == 0 {
            // no dictionary
            (
                None,
                buffers
                    .iter()
                    .map(|b| b.clone().try_to_host_sync())
                    .collect::<VortexResult<Vec<_>>>()?,
            )
        } else {
            // with dictionary
            (
                Some(buffers[0].clone().try_to_host_sync()?),
                buffers[1..]
                    .iter()
                    .map(|b| b.clone().try_to_host_sync())
                    .collect::<VortexResult<Vec<_>>>()?,
            )
        };

        let slots = vec![validity_to_child(&validity, len)];
        let data = ZstdData::new(dictionary_buffer, compressed_buffers, metadata, len);
        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
    }

    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
        SLOT_NAMES[idx].to_string()
    }

    /// Decompress the array and continue execution on the result.
    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
        // Slot 0 holds the validity child covering the unsliced rows.
        let unsliced_validity =
            child_to_validity(&array.as_ref().slots()[0], array.dtype().nullability());
        array
            .data()
            .decompress(array.dtype(), &unsliced_validity, ctx)?
            .execute::<ArrayRef>(ctx)
            .map(ExecutionResult::done)
    }

    /// Delegates parent-reduction to the crate's rewrite rules.
    fn reduce_parent(
        array: ArrayView<'_, Self>,
        parent: &ArrayRef,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        crate::rules::RULES.evaluate(array, parent, child_idx)
    }
}
253
/// Marker type carrying the [`VTable`] implementation for Zstd-encoded arrays.
#[derive(Clone, Debug)]
pub struct Zstd;
256
impl Zstd {
    /// Construct a [`ZstdArray`] from already-compressed `data`, validating it
    /// against `dtype` and `validity` first.
    pub fn try_new(dtype: DType, data: ZstdData, validity: Validity) -> VortexResult<ZstdArray> {
        let len = data.len();
        data.validate(&dtype, len, &validity)?;
        // The validity child covers the unsliced rows, not just the slice.
        let slots = vec![validity_to_child(&validity, data.unsliced_n_rows())];
        // SAFETY: `data.validate` above checked the invariants that the
        // unchecked constructor assumes.
        Ok(unsafe {
            Array::from_parts_unchecked(ArrayParts::new(Zstd, dtype, len, data).with_slots(slots))
        })
    }

    /// Compress a [`VarBinViewArray`] using Zstd without a dictionary.
    pub fn from_var_bin_view_without_dict(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<ZstdArray> {
        let validity = vbv.validity()?;
        Self::try_new(
            vbv.dtype().clone(),
            ZstdData::from_var_bin_view_without_dict(vbv, level, values_per_frame)?,
            validity,
        )
    }

    /// Compress a [`PrimitiveArray`] using Zstd.
    pub fn from_primitive(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<ZstdArray> {
        let validity = parray.validity()?;
        Self::try_new(
            parray.dtype().clone(),
            ZstdData::from_primitive(parray, level, values_per_frame)?,
            validity,
        )
    }

    /// Compress a [`VarBinViewArray`] using Zstd.
    pub fn from_var_bin_view(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<ZstdArray> {
        let validity = vbv.validity()?;
        Self::try_new(
            vbv.dtype().clone(),
            ZstdData::from_var_bin_view(vbv, level, values_per_frame)?,
            validity,
        )
    }

    /// Decompress `array`, materializing the encoded values as a plain array.
    pub fn decompress(array: &ZstdArray, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
        // Slot 0 holds the validity child covering the unsliced rows.
        let unsliced_validity =
            child_to_validity(&array.as_ref().slots()[0], array.dtype().nullability());
        array
            .data()
            .decompress(array.dtype(), &unsliced_validity, ctx)
    }
}
317
/// Number of child slots: a single slot holding the validity bitmap that
/// indicates which elements are non-null.
pub(super) const NUM_SLOTS: usize = 1;
pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
321
/// The data backing a [`ZstdArray`]: compressed frames plus slicing state.
#[derive(Clone, Debug)]
pub struct ZstdData {
    // Shared zstd dictionary the frames were compressed with, if any.
    pub(crate) dictionary: Option<ByteBuffer>,
    // One compressed zstd frame per entry; parallel to `metadata.frames`.
    pub(crate) frames: Vec<ByteBuffer>,
    // Serialized description of the dictionary size and each frame.
    pub(crate) metadata: ZstdMetadata,
    // Total row count before any slicing was applied.
    unsliced_n_rows: usize,
    // Current logical slice bounds within `0..unsliced_n_rows`.
    slice_start: usize,
    slice_stop: usize,
}
331
332impl Display for ZstdData {
333    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
334        write!(
335            f,
336            "nrows: {}, slice: {}..{}",
337            self.unsliced_n_rows, self.slice_start, self.slice_stop
338        )
339    }
340}
341
/// The parts of a [`ZstdArray`] returned by [`ZstdArray::into_parts`].
pub struct ZstdDataParts {
    pub dictionary: Option<ByteBuffer>,
    pub frames: Vec<ByteBuffer>,
    pub metadata: ZstdMetadata,
    pub validity: Validity,
    pub n_rows: usize,
    pub slice_start: usize,
    pub slice_stop: usize,
}

/// Result of [`ZstdData::compress_values`]: the trained dictionary (if any),
/// the compressed frames, and the per-frame metadata.
#[derive(Debug)]
struct Frames {
    dictionary: Option<ByteBuffer>,
    frames: Vec<ByteBuffer>,
    frame_metas: Vec<ZstdFrameMetadata>,
}
359
/// Pick a dictionary-size budget for `uncompressed_size` bytes of input.
///
/// Follows the recommendation in zstd's zdict.h (v1.5.5, around L190):
/// roughly 1/100 of the data size, capped at 100 KiB. The 256-byte floor
/// exists because zstd appears unable to train dictionaries smaller than that.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const MIN_DICT_SIZE: usize = 256;
    const MAX_DICT_SIZE: usize = 100 * 1024;
    (uncompressed_size / 100).clamp(MIN_DICT_SIZE, MAX_DICT_SIZE)
}
367
368fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
369    let mask = parray.as_ref().validity()?.to_mask(
370        parray.as_ref().len(),
371        &mut LEGACY_SESSION.create_execution_ctx(),
372    )?;
373    Ok(parray.filter(mask)?.to_primitive())
374}
375
376fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
377    let mask = vbv.as_ref().validity()?.to_mask(
378        vbv.as_ref().len(),
379        &mut LEGACY_SESSION.create_execution_ctx(),
380    )?;
381    let buffer_and_value_byte_indices = match mask.bit_buffer() {
382        AllOr::None => (Buffer::empty(), Vec::new()),
383        _ => {
384            let mut buffer = BufferMut::with_capacity(
385                usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
386                    + mask.true_count() * size_of::<ViewLen>(),
387            );
388            let mut value_byte_indices = Vec::new();
389            vbv.with_iterator(|iterator| {
390                // by flattening, we should omit nulls
391                for value in iterator.flatten() {
392                    value_byte_indices.push(buffer.len());
393                    // here's where we write the string lengths
394                    buffer
395                        .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
396                    buffer.extend_from_slice(value);
397                }
398                Ok::<_, VortexError>(())
399            })?;
400            (buffer.freeze(), value_byte_indices)
401        }
402    };
403    Ok(buffer_and_value_byte_indices)
404}
405
/// Reconstruct BinaryView structs from length-prefixed byte data.
///
/// The buffer contains interleaved u32 lengths (little-endian) and string data.
/// When the cumulative data exceeds `max_buffer_len`, the buffer is split (zero-copy) into
/// multiple segments so that BinaryView's u32 offsets can address all data.
///
/// Pass [`MAX_BUFFER_LEN`] for `max_buffer_len` in production; a smaller value can be used in
/// tests to exercise the splitting path without allocating >2 GiB.
pub fn reconstruct_views(
    buffer: &ByteBuffer,
    max_buffer_len: usize,
) -> (Vec<ByteBuffer>, Buffer<BinaryView>) {
    let mut views = BufferMut::<BinaryView>::empty();
    let mut buffers = Vec::new();
    // Start (within `buffer`) of the segment currently being filled.
    let mut segment_start: usize = 0;
    // Read cursor over the length-prefixed input.
    let mut offset = 0;

    while offset < buffer.len() {
        // Read the little-endian length prefix for the next value.
        let str_len = ViewLen::from_le_bytes(
            buffer
                .get(offset..offset + size_of::<ViewLen>())
                .vortex_expect("corrupted zstd length")
                .try_into()
                .ok()
                .vortex_expect("must fit ViewLen size"),
        ) as usize;

        let value_data_offset = offset + size_of::<ViewLen>();
        let local_offset = value_data_offset - segment_start;

        // If this value would push the segment past `max_buffer_len`, close the
        // current segment (zero-copy slice) and start a new one at this value.
        // The `offset > segment_start` guard avoids emitting an empty segment
        // when a single value alone exceeds the limit.
        if local_offset + str_len > max_buffer_len && offset > segment_start {
            buffers.push(buffer.slice(segment_start..offset));
            segment_start = offset;
        }

        // View offsets are relative to the segment the view points into.
        let local_offset = u32::try_from(value_data_offset - segment_start)
            .vortex_expect("local offset within segment must fit in u32");
        let buf_index = u32::try_from(buffers.len()).vortex_expect("buffer index must fit in u32");
        let value = &buffer[value_data_offset..value_data_offset + str_len];
        views.push(BinaryView::make_view(value, buf_index, local_offset));
        offset = value_data_offset + str_len;
    }

    // Flush the trailing segment if any bytes remain in it.
    if segment_start < buffer.len() {
        buffers.push(buffer.slice(segment_start..buffer.len()));
    }

    (buffers, views.freeze())
}
455
456impl ZstdData {
    /// Build a `ZstdData` covering all `n_rows` rows (an unsliced view).
    ///
    /// No validation is performed here; see [`Self::validate`].
    pub fn new(
        dictionary: Option<ByteBuffer>,
        frames: Vec<ByteBuffer>,
        metadata: ZstdMetadata,
        n_rows: usize,
    ) -> Self {
        Self {
            dictionary,
            frames,
            metadata,
            unsliced_n_rows: n_rows,
            // A fresh array is an identity slice over all rows.
            slice_start: 0,
            slice_stop: n_rows,
        }
    }
472
    /// Check this data's internal consistency against the array-level `dtype`,
    /// `len`, and `validity`.
    ///
    /// # Errors
    /// Errors if the dtype is not primitive/binary/utf8, the slice bounds are
    /// inconsistent with `len` or the unsliced row count, the validity length
    /// disagrees with the unsliced row count, or the dictionary/frame metadata
    /// does not match the actual buffers.
    pub fn validate(&self, dtype: &DType, len: usize, validity: &Validity) -> VortexResult<()> {
        vortex_ensure!(
            matches!(
                dtype,
                DType::Primitive(..) | DType::Binary(_) | DType::Utf8(_)
            ),
            "Unsupported dtype for Zstd array: {dtype}"
        );
        vortex_ensure!(
            self.slice_start <= self.slice_stop,
            "Invalid slice range {}..{}",
            self.slice_start,
            self.slice_stop
        );
        vortex_ensure!(
            self.slice_stop <= self.unsliced_n_rows,
            "Slice stop {} exceeds unsliced row count {}",
            self.slice_stop,
            self.unsliced_n_rows
        );
        vortex_ensure!(
            self.slice_stop - self.slice_start == len,
            "Slice length {} does not match array length {}",
            self.slice_stop - self.slice_start,
            len
        );
        // Validity may be a statically-known (length-less) variant; only check
        // when an explicit length is available.
        if let Some(validity_len) = validity.maybe_len() {
            vortex_ensure!(
                validity_len == self.unsliced_n_rows,
                "Validity length {} does not match unsliced row count {}",
                validity_len,
                self.unsliced_n_rows
            );
        }

        // The metadata's dictionary_size doubles as the "has dictionary" flag,
        // so it must agree with the buffer's presence and size.
        match &self.dictionary {
            Some(dictionary) => vortex_ensure!(
                usize::try_from(self.metadata.dictionary_size)? == dictionary.len(),
                "Dictionary size metadata {} does not match buffer size {}",
                self.metadata.dictionary_size,
                dictionary.len()
            ),
            None => vortex_ensure!(
                self.metadata.dictionary_size == 0,
                "Dictionary metadata present without dictionary buffer"
            ),
        }
        vortex_ensure!(
            self.frames.len() == self.metadata.frames.len(),
            "Frame count {} does not match metadata frame count {}",
            self.frames.len(),
            self.metadata.frames.len()
        );

        Ok(())
    }
529
530    pub(crate) fn with_slice(&self, start: usize, stop: usize) -> Self {
531        let new_start = self.slice_start + start;
532        let new_stop = self.slice_start + stop;
533
534        assert!(
535            new_start <= self.slice_stop,
536            "new slice start {new_start} exceeds end {}",
537            self.slice_stop
538        );
539
540        assert!(
541            new_stop <= self.slice_stop,
542            "new slice stop {new_stop} exceeds end {}",
543            self.slice_stop
544        );
545
546        Self {
547            slice_start: new_start,
548            slice_stop: new_stop,
549            ..self.clone()
550        }
551    }
552
    /// Compress `value_bytes` into one zstd frame per entry of
    /// `frame_byte_starts`, optionally training a shared dictionary first.
    ///
    /// `frame_byte_starts[i]` is the byte offset where frame `i` begins; frame
    /// `i` ends where frame `i + 1` begins (or at the end of `value_bytes`).
    fn compress_values(
        value_bytes: &ByteBuffer,
        frame_byte_starts: &[usize],
        level: i32,
        values_per_frame: usize,
        n_values: usize,
        use_dictionary: bool,
    ) -> VortexResult<Frames> {
        let n_frames = frame_byte_starts.len();

        // Would-be sample sizes if we end up applying zstd dictionary
        let mut sample_sizes = Vec::with_capacity(n_frames);
        for i in 0..n_frames {
            let frame_byte_end = frame_byte_starts
                .get(i + 1)
                .copied()
                .unwrap_or(value_bytes.len());
            sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
        }
        // The frames must tile value_bytes exactly.
        debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());

        // Skip dictionary training when disabled or when there are too few
        // samples for zstd to train on (see MIN_SAMPLES_FOR_DICTIONARY).
        let (dictionary, mut compressor) = if !use_dictionary
            || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY
        {
            // no dictionary
            (None, zstd::bulk::Compressor::new(level)?)
        } else {
            // with dictionary
            let max_dict_size = choose_max_dict_size(value_bytes.len());
            let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
                .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;

            let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
            (Some(ByteBuffer::from(dict)), compressor)
        };

        let mut frame_metas = vec![];
        let mut frames = vec![];
        for i in 0..n_frames {
            let frame_byte_end = frame_byte_starts
                .get(i + 1)
                .copied()
                .unwrap_or(value_bytes.len());

            let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
            let compressed = compressor
                .compress(uncompressed)
                .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
            frame_metas.push(ZstdFrameMetadata {
                uncompressed_size: uncompressed.len() as u64,
                // NOTE(review): this assumes every frame holds exactly
                // `values_per_frame` values except possibly the last. The
                // primitive path rounds frame boundaries up to the buffer
                // alignment, which could skew per-frame counts — confirm.
                n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
            });
            frames.push(ByteBuffer::from(compressed));
        }

        Ok(Frames {
            dictionary,
            frames,
            frame_metas,
        })
    }
614
    /// Creates a ZstdArray from a primitive array.
    ///
    /// A shared dictionary may be trained across frames when there are enough
    /// of them; see [`Self::from_primitive_without_dict`] to disable that.
    ///
    /// # Arguments
    /// * `parray` - The primitive array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_primitive(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, true)
    }
628
    /// Creates a ZstdArray from a primitive array without using a dictionary.
    ///
    /// This is useful when the compressed data will be decompressed by systems
    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
    ///
    /// Note: Without a dictionary, each frame is compressed independently.
    /// Dictionaries are trained from sample data from previously seen frames,
    /// to improve compression ratio.
    ///
    /// See [`Self::from_primitive`] for the dictionary-enabled variant.
    ///
    /// # Arguments
    /// * `parray` - The primitive array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_primitive_without_dict(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, false)
    }
649
650    fn from_primitive_impl(
651        parray: &PrimitiveArray,
652        level: i32,
653        values_per_frame: usize,
654        use_dictionary: bool,
655    ) -> VortexResult<Self> {
656        let byte_width = parray.ptype().byte_width();
657
658        // We compress only the valid elements.
659        let values = collect_valid_primitive(parray)?;
660        let n_values = values.len();
661        let values_per_frame = if values_per_frame > 0 {
662            values_per_frame
663        } else {
664            n_values
665        };
666
667        let value_bytes = values.buffer_handle().try_to_host_sync()?;
668        // Align frames to buffer alignment. This is necessary for overaligned buffers.
669        let alignment = *value_bytes.alignment();
670        let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment;
671
672        let frame_byte_starts = (0..n_values * byte_width)
673            .step_by(step_width)
674            .collect::<Vec<_>>();
675        let Frames {
676            dictionary,
677            frames,
678            frame_metas,
679        } = Self::compress_values(
680            &value_bytes,
681            &frame_byte_starts,
682            level,
683            values_per_frame,
684            n_values,
685            use_dictionary,
686        )?;
687
688        let metadata = ZstdMetadata {
689            dictionary_size: dictionary
690                .as_ref()
691                .map_or(0, |dict| dict.len())
692                .try_into()?,
693            frames: frame_metas,
694        };
695
696        Ok(ZstdData::new(dictionary, frames, metadata, parray.len()))
697    }
698
    /// Creates a ZstdArray from a VarBinView array.
    ///
    /// A shared dictionary may be trained across frames when there are enough
    /// of them; see [`Self::from_var_bin_view_without_dict`] to disable that.
    ///
    /// # Arguments
    /// * `vbv` - The VarBinView array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_var_bin_view(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, true)
    }
712
    /// Creates a ZstdArray from a VarBinView array without using a dictionary.
    ///
    /// This is useful when the compressed data will be decompressed by systems
    /// that don't support ZSTD dictionaries (e.g., nvCOMP on GPU).
    ///
    /// Note: Without a dictionary, each frame is compressed independently.
    /// Dictionaries are trained from sample data from previously seen frames,
    /// to improve compression ratio.
    ///
    /// See [`Self::from_var_bin_view`] for the dictionary-enabled variant.
    ///
    /// # Arguments
    /// * `vbv` - The VarBinView array to compress
    /// * `level` - Zstd compression level (0 = default, negative = fast, positive = better compression)
    /// * `values_per_frame` - Number of values per frame (0 = single frame)
    pub fn from_var_bin_view_without_dict(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, false)
    }
733
734    fn from_var_bin_view_impl(
735        vbv: &VarBinViewArray,
736        level: i32,
737        values_per_frame: usize,
738        use_dictionary: bool,
739    ) -> VortexResult<Self> {
740        // Approach for strings: we prefix each string with its length as a u32.
741        // This is the same as what Parquet does. In some cases it may be better
742        // to separate the binary data and lengths as two separate streams, but
743        // this approach is simpler and can be best in cases when there is
744        // mutual information between strings and their lengths.
745        // We compress only the valid elements.
746        let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
747        let n_values = value_byte_indices.len();
748        let values_per_frame = if values_per_frame > 0 {
749            values_per_frame
750        } else {
751            n_values
752        };
753
754        let frame_byte_starts = (0..n_values)
755            .step_by(values_per_frame)
756            .map(|i| value_byte_indices[i])
757            .collect::<Vec<_>>();
758        let Frames {
759            dictionary,
760            frames,
761            frame_metas,
762        } = Self::compress_values(
763            &value_bytes,
764            &frame_byte_starts,
765            level,
766            values_per_frame,
767            n_values,
768            use_dictionary,
769        )?;
770
771        let metadata = ZstdMetadata {
772            dictionary_size: dictionary
773                .as_ref()
774                .map_or(0, |dict| dict.len())
775                .try_into()?,
776            frames: frame_metas,
777        };
778        Ok(ZstdData::new(dictionary, frames, metadata, vbv.len()))
779    }
780
781    pub fn from_canonical(
782        canonical: &Canonical,
783        level: i32,
784        values_per_frame: usize,
785    ) -> VortexResult<Option<Self>> {
786        match canonical {
787            Canonical::Primitive(parray) => Ok(Some(ZstdData::from_primitive(
788                parray,
789                level,
790                values_per_frame,
791            )?)),
792            Canonical::VarBinView(vbv) => Ok(Some(ZstdData::from_var_bin_view(
793                vbv,
794                level,
795                values_per_frame,
796            )?)),
797            _ => Ok(None),
798        }
799    }
800
801    pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
802        Self::from_canonical(&array.to_canonical()?, level, values_per_frame)?
803            .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
804    }
805
806    fn byte_width(dtype: &DType) -> usize {
807        if dtype.is_primitive() {
808            dtype.as_ptype().byte_width()
809        } else {
810            1
811        }
812    }
813
814    fn decompress(
815        &self,
816        dtype: &DType,
817        unsliced_validity: &Validity,
818        ctx: &mut ExecutionCtx,
819    ) -> VortexResult<ArrayRef> {
820        // To start, we figure out which frames we need to decompress, and with
821        // what row offset into the first such frame.
822        let byte_width = Self::byte_width(dtype);
823        let slice_n_rows = self.slice_stop - self.slice_start;
824        let slice_value_indices = unsliced_validity
825            .execute_mask(self.unsliced_n_rows, ctx)?
826            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
827
828        let slice_value_idx_start = slice_value_indices[0];
829        let slice_value_idx_stop = slice_value_indices[1];
830
831        let mut frames_to_decompress = vec![];
832        let mut value_idx_start = 0;
833        let mut uncompressed_size_to_decompress = 0;
834        let mut n_skipped_values = 0;
835        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
836            if value_idx_start >= slice_value_idx_stop {
837                break;
838            }
839
840            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
841                .vortex_expect("Uncompressed size must fit in usize");
842            let frame_n_values = if frame_meta.n_values == 0 {
843                // possibly older primitive-only metadata that just didn't store this
844                frame_uncompressed_size / byte_width
845            } else {
846                usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
847            };
848
849            let value_idx_stop = value_idx_start + frame_n_values;
850            if value_idx_stop > slice_value_idx_start {
851                // we need this frame
852                frames_to_decompress.push(frame);
853                uncompressed_size_to_decompress += frame_uncompressed_size;
854            } else {
855                n_skipped_values += frame_n_values;
856            }
857            value_idx_start = value_idx_stop;
858        }
859
860        // then we actually decompress those frames
861        let mut decompressor = if let Some(dictionary) = &self.dictionary {
862            zstd::bulk::Decompressor::with_dictionary(dictionary)?
863        } else {
864            zstd::bulk::Decompressor::new()?
865        };
866        let mut decompressed = ByteBufferMut::with_capacity_aligned(
867            uncompressed_size_to_decompress,
868            Alignment::new(byte_width),
869        );
870        unsafe {
871            // safety: we immediately fill all bytes in the following loop,
872            // assuming our metadata's uncompressed size is correct
873            decompressed.set_len(uncompressed_size_to_decompress);
874        }
875        let mut uncompressed_start = 0;
876        for frame in frames_to_decompress {
877            let uncompressed_written = decompressor
878                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])?;
879            uncompressed_start += uncompressed_written;
880        }
881        if uncompressed_start != uncompressed_size_to_decompress {
882            vortex_panic!(
883                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
884                uncompressed_size_to_decompress,
885                uncompressed_start
886            );
887        }
888
889        let decompressed = decompressed.freeze();
890        // Last, we slice the exact values requested out of the decompressed data.
891        let mut slice_validity = unsliced_validity.slice(self.slice_start..self.slice_stop)?;
892
893        // NOTE: this block handles setting the output type when the validity and DType disagree.
894        //
895        // ZSTD is a compact block compressor, meaning that null values are not stored inline in
896        // the data frames. A ZSTD Array that was initialized must always hold onto its full
897        // validity bitmap, even if sliced to only include non-null values.
898        //
899        // We ensure that the validity of the decompressed array ALWAYS matches the validity
900        // implied by the DType.
901        if !dtype.is_nullable() && !matches!(slice_validity, Validity::NonNullable) {
902            assert!(
903                matches!(slice_validity, Validity::AllValid),
904                "ZSTD array expects to be non-nullable but there are nulls after decompression"
905            );
906
907            slice_validity = Validity::NonNullable;
908        } else if dtype.is_nullable() && matches!(slice_validity, Validity::NonNullable) {
909            slice_validity = Validity::AllValid;
910        }
911        //
912        // END OF IMPORTANT BLOCK
913        //
914
915        match dtype {
916            DType::Primitive(..) => {
917                let slice_values_buffer = decompressed.slice(
918                    (slice_value_idx_start - n_skipped_values) * byte_width
919                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
920                );
921                let primitive = PrimitiveArray::from_values_byte_buffer(
922                    slice_values_buffer,
923                    dtype.as_ptype(),
924                    slice_validity,
925                    slice_n_rows,
926                );
927
928                Ok(primitive.into_array())
929            }
930            DType::Binary(_) | DType::Utf8(_) => {
931                match slice_validity.execute_mask(slice_n_rows, ctx)?.indices() {
932                    AllOr::All => {
933                        let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
934                        let valid_views = all_views.slice(
935                            slice_value_idx_start - n_skipped_values
936                                ..slice_value_idx_stop - n_skipped_values,
937                        );
938
939                        // SAFETY: we properly construct the views inside `reconstruct_views`
940                        Ok(unsafe {
941                            VarBinViewArray::new_unchecked(
942                                valid_views,
943                                Arc::from(buffers),
944                                dtype.clone(),
945                                slice_validity,
946                            )
947                        }
948                        .into_array())
949                    }
950                    AllOr::None => Ok(ConstantArray::new(
951                        Scalar::null(dtype.clone()),
952                        slice_n_rows,
953                    )
954                    .into_array()),
955                    AllOr::Some(valid_indices) => {
956                        let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
957                        let valid_views = all_views.slice(
958                            slice_value_idx_start - n_skipped_values
959                                ..slice_value_idx_stop - n_skipped_values,
960                        );
961
962                        let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
963                        for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
964                            views[*index] = view
965                        }
966
967                        // SAFETY: we properly construct the views inside `reconstruct_views`
968                        Ok(unsafe {
969                            VarBinViewArray::new_unchecked(
970                                views.freeze(),
971                                Arc::from(buffers),
972                                dtype.clone(),
973                                slice_validity,
974                            )
975                        }
976                        .into_array())
977                    }
978                }
979            }
980            _ => vortex_panic!("Unsupported dtype for Zstd array: {}", dtype),
981        }
982    }
983
984    /// Returns the length of the array.
985    #[inline]
986    pub fn len(&self) -> usize {
987        self.slice_stop - self.slice_start
988    }
989
990    /// Returns whether the array is empty.
991    #[inline]
992    pub fn is_empty(&self) -> bool {
993        self.slice_stop == self.slice_start
994    }
995
996    pub fn into_parts(self, validity: Validity) -> ZstdDataParts {
997        ZstdDataParts {
998            dictionary: self.dictionary,
999            frames: self.frames,
1000            metadata: self.metadata,
1001            validity,
1002            n_rows: self.unsliced_n_rows,
1003            slice_start: self.slice_start,
1004            slice_stop: self.slice_stop,
1005        }
1006    }
1007
    /// Start (inclusive) of the visible slice, as a row offset into the
    /// unsliced array.
    pub(crate) fn slice_start(&self) -> usize {
        self.slice_start
    }
1011
    /// End (exclusive) of the visible slice, as a row offset into the
    /// unsliced array.
    pub(crate) fn slice_stop(&self) -> usize {
        self.slice_stop
    }
1015
    /// Total row count of the underlying unsliced array (nulls included);
    /// this is the length of the full validity mask.
    pub(crate) fn unsliced_n_rows(&self) -> usize {
        self.unsliced_n_rows
    }
1019}
1020
1021impl ValidityVTable<Zstd> for Zstd {
1022    fn validity(array: ArrayView<'_, Zstd>) -> VortexResult<Validity> {
1023        let unsliced_validity = child_to_validity(&array.slots()[0], array.dtype().nullability());
1024        unsliced_validity.slice(array.slice_start()..array.slice_stop())
1025    }
1026}
1027
1028impl OperationsVTable<Zstd> for Zstd {
1029    fn scalar_at(
1030        array: ArrayView<'_, Zstd>,
1031        index: usize,
1032        ctx: &mut ExecutionCtx,
1033    ) -> VortexResult<Scalar> {
1034        let unsliced_validity = child_to_validity(&array.slots()[0], array.dtype().nullability());
1035        let sliced = array.data().with_slice(index, index + 1);
1036        sliced
1037            .decompress(array.dtype(), &unsliced_validity, ctx)?
1038            .execute_scalar(0, ctx)
1039    }
1040}
1041
#[cfg(test)]
#[expect(clippy::cast_possible_truncation)]
mod tests {
    use vortex_buffer::ByteBuffer;

    use super::reconstruct_views;
    use crate::array::BinaryView;

    /// Encode `strings` in the Zstd interleaved layout: a little-endian u32
    /// length prefix followed by the raw bytes, repeated per entry.
    fn make_interleaved(strings: &[&[u8]]) -> ByteBuffer {
        let bytes: Vec<u8> = strings
            .iter()
            .flat_map(|s| {
                let mut entry = (s.len() as u32).to_le_bytes().to_vec();
                entry.extend_from_slice(s);
                entry
            })
            .collect();
        ByteBuffer::copy_from(bytes.as_slice())
    }

    #[test]
    fn test_reconstruct_views_no_split() {
        let buf = make_interleaved(&[b"hello".as_slice(), b"world".as_slice()]);
        let (buffers, views) = reconstruct_views(&buf, 1024);

        // Everything fits in one segment.
        assert_eq!(buffers.len(), 1);
        assert_eq!(views.len(), 2);
        // Data offsets skip the 4-byte length prefixes: 4, then 4 + 5 + 4 = 13.
        assert_eq!(views[0], BinaryView::make_view(b"hello", 0, 4));
        assert_eq!(views[1], BinaryView::make_view(b"world", 0, 13));
    }

    #[test]
    fn test_reconstruct_views_split_across_segments() {
        // Two 13-byte strings; each entry occupies 4 (prefix) + 13 (data) = 17
        // bytes. With max_buffer_len = 20 the second entry's data would start
        // at offset 4 + 13 + 4 = 21, past the limit, so it rolls over into a
        // second segment.
        let buf = make_interleaved(&[b"aaaaaaaaaaaaa".as_slice(), b"bbbbbbbbbbbbb".as_slice()]);
        let (buffers, views) = reconstruct_views(&buf, 20);

        assert_eq!(buffers.len(), 2);
        assert_eq!(views.len(), 2);
        assert_eq!(views[0], BinaryView::make_view(b"aaaaaaaaaaaaa", 0, 4));
        // The second segment starts at byte 17 (its length prefix), so the
        // data's segment-local offset is again 4.
        assert_eq!(views[1], BinaryView::make_view(b"bbbbbbbbbbbbb", 1, 4));
    }
}