1use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use itertools::Itertools as _;
9use prost::Message as _;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayRef;
13use vortex_array::Canonical;
14use vortex_array::DynArray;
15use vortex_array::ExecutionCtx;
16use vortex_array::ExecutionResult;
17use vortex_array::IntoArray;
18use vortex_array::LEGACY_SESSION;
19use vortex_array::Precision;
20use vortex_array::ProstMetadata;
21use vortex_array::ToCanonical;
22use vortex_array::VortexSessionExecute;
23use vortex_array::accessor::ArrayAccessor;
24use vortex_array::arrays::ConstantArray;
25use vortex_array::arrays::PrimitiveArray;
26use vortex_array::arrays::VarBinViewArray;
27use vortex_array::arrays::varbinview::build_views::BinaryView;
28use vortex_array::arrays::varbinview::build_views::MAX_BUFFER_LEN;
29use vortex_array::buffer::BufferHandle;
30use vortex_array::dtype::DType;
31use vortex_array::scalar::Scalar;
32use vortex_array::serde::ArrayChildren;
33use vortex_array::stats::ArrayStats;
34use vortex_array::stats::StatsSetRef;
35use vortex_array::validity::Validity;
36use vortex_array::vtable;
37use vortex_array::vtable::ArrayId;
38use vortex_array::vtable::OperationsVTable;
39use vortex_array::vtable::VTable;
40use vortex_array::vtable::ValidityHelper;
41use vortex_array::vtable::ValiditySliceHelper;
42use vortex_array::vtable::ValidityVTableFromValiditySliceHelper;
43use vortex_array::vtable::validity_nchildren;
44use vortex_array::vtable::validity_to_child;
45use vortex_buffer::Alignment;
46use vortex_buffer::Buffer;
47use vortex_buffer::BufferMut;
48use vortex_buffer::ByteBuffer;
49use vortex_buffer::ByteBufferMut;
50use vortex_error::VortexError;
51use vortex_error::VortexExpect;
52use vortex_error::VortexResult;
53use vortex_error::vortex_bail;
54use vortex_error::vortex_ensure;
55use vortex_error::vortex_err;
56use vortex_error::vortex_panic;
57use vortex_mask::AllOr;
58use vortex_session::VortexSession;
59
60use crate::ZstdFrameMetadata;
61use crate::ZstdMetadata;
62
/// Minimum number of frames (dictionary training samples) required before a zstd
/// dictionary is trained; with fewer frames the data is compressed without one.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
/// Integer type of the little-endian length prefix written in front of every
/// variable-length value in the uncompressed frame payload.
type ViewLen = u32;

// NOTE(review): presumably generates the vtable glue for the `Zstd` encoding — confirm
// against the `vortex_array::vtable!` macro definition.
vtable!(Zstd);
86
87impl VTable for Zstd {
88 type Array = ZstdArray;
89
90 type Metadata = ProstMetadata<ZstdMetadata>;
91 type OperationsVTable = Self;
92 type ValidityVTable = ValidityVTableFromValiditySliceHelper;
93
94 fn vtable(_array: &Self::Array) -> &Self {
95 &Zstd
96 }
97
98 fn id(&self) -> ArrayId {
99 Self::ID
100 }
101
102 fn len(array: &ZstdArray) -> usize {
103 array.slice_stop - array.slice_start
104 }
105
106 fn dtype(array: &ZstdArray) -> &DType {
107 &array.dtype
108 }
109
110 fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
111 array.stats_set.to_ref(array.as_ref())
112 }
113
114 fn array_hash<H: std::hash::Hasher>(array: &ZstdArray, state: &mut H, precision: Precision) {
115 match &array.dictionary {
116 Some(dict) => {
117 true.hash(state);
118 dict.array_hash(state, precision);
119 }
120 None => {
121 false.hash(state);
122 }
123 }
124 for frame in &array.frames {
125 frame.array_hash(state, precision);
126 }
127 array.dtype.hash(state);
128 array.unsliced_validity.array_hash(state, precision);
129 array.unsliced_n_rows.hash(state);
130 array.slice_start.hash(state);
131 array.slice_stop.hash(state);
132 }
133
134 fn array_eq(array: &ZstdArray, other: &ZstdArray, precision: Precision) -> bool {
135 if !match (&array.dictionary, &other.dictionary) {
136 (Some(d1), Some(d2)) => d1.array_eq(d2, precision),
137 (None, None) => true,
138 _ => false,
139 } {
140 return false;
141 }
142 if array.frames.len() != other.frames.len() {
143 return false;
144 }
145 for (a, b) in array.frames.iter().zip(&other.frames) {
146 if !a.array_eq(b, precision) {
147 return false;
148 }
149 }
150 array.dtype == other.dtype
151 && array
152 .unsliced_validity
153 .array_eq(&other.unsliced_validity, precision)
154 && array.unsliced_n_rows == other.unsliced_n_rows
155 && array.slice_start == other.slice_start
156 && array.slice_stop == other.slice_stop
157 }
158
159 fn nbuffers(array: &ZstdArray) -> usize {
160 array.dictionary.is_some() as usize + array.frames.len()
161 }
162
163 fn buffer(array: &ZstdArray, idx: usize) -> BufferHandle {
164 if let Some(dict) = &array.dictionary {
165 if idx == 0 {
166 return BufferHandle::new_host(dict.clone());
167 }
168 BufferHandle::new_host(array.frames[idx - 1].clone())
169 } else {
170 BufferHandle::new_host(array.frames[idx].clone())
171 }
172 }
173
174 fn buffer_name(array: &ZstdArray, idx: usize) -> Option<String> {
175 if array.dictionary.is_some() {
176 if idx == 0 {
177 Some("dictionary".to_string())
178 } else {
179 Some(format!("frame_{}", idx - 1))
180 }
181 } else {
182 Some(format!("frame_{idx}"))
183 }
184 }
185
186 fn nchildren(array: &ZstdArray) -> usize {
187 validity_nchildren(&array.unsliced_validity)
188 }
189
190 fn child(array: &ZstdArray, idx: usize) -> ArrayRef {
191 validity_to_child(&array.unsliced_validity, array.unsliced_n_rows)
192 .unwrap_or_else(|| vortex_panic!("ZstdArray child index {idx} out of bounds"))
193 }
194
195 fn child_name(_array: &ZstdArray, idx: usize) -> String {
196 match idx {
197 0 => "validity".to_string(),
198 _ => vortex_panic!("ZstdArray child_name index {idx} out of bounds"),
199 }
200 }
201
202 fn metadata(array: &ZstdArray) -> VortexResult<Self::Metadata> {
203 Ok(ProstMetadata(array.metadata.clone()))
204 }
205
206 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
207 Ok(Some(metadata.0.encode_to_vec()))
208 }
209
210 fn deserialize(
211 bytes: &[u8],
212 _dtype: &DType,
213 _len: usize,
214 _buffers: &[BufferHandle],
215 _session: &VortexSession,
216 ) -> VortexResult<Self::Metadata> {
217 Ok(ProstMetadata(ZstdMetadata::decode(bytes)?))
218 }
219
220 fn build(
221 dtype: &DType,
222 len: usize,
223 metadata: &Self::Metadata,
224 buffers: &[BufferHandle],
225 children: &dyn ArrayChildren,
226 ) -> VortexResult<ZstdArray> {
227 let validity = if children.is_empty() {
228 Validity::from(dtype.nullability())
229 } else if children.len() == 1 {
230 let validity = children.get(0, &Validity::DTYPE, len)?;
231 Validity::Array(validity)
232 } else {
233 vortex_bail!("ZstdArray expected 0 or 1 child, got {}", children.len());
234 };
235
236 let (dictionary_buffer, compressed_buffers) = if metadata.0.dictionary_size == 0 {
237 (
239 None,
240 buffers
241 .iter()
242 .map(|b| b.clone().try_to_host_sync())
243 .collect::<VortexResult<Vec<_>>>()?,
244 )
245 } else {
246 (
248 Some(buffers[0].clone().try_to_host_sync()?),
249 buffers[1..]
250 .iter()
251 .map(|b| b.clone().try_to_host_sync())
252 .collect::<VortexResult<Vec<_>>>()?,
253 )
254 };
255
256 Ok(ZstdArray::new(
257 dictionary_buffer,
258 compressed_buffers,
259 dtype.clone(),
260 metadata.0.clone(),
261 len,
262 validity,
263 ))
264 }
265
266 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
267 vortex_ensure!(
268 children.len() <= 1,
269 "ZstdArray expects at most 1 child (validity), got {}",
270 children.len()
271 );
272
273 array.unsliced_validity = if children.is_empty() {
274 Validity::from(array.dtype.nullability())
275 } else {
276 Validity::Array(children.into_iter().next().vortex_expect("checked"))
277 };
278
279 Ok(())
280 }
281
282 fn execute(array: Arc<Self::Array>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
283 array
284 .decompress(ctx)?
285 .execute::<ArrayRef>(ctx)
286 .map(ExecutionResult::done)
287 }
288
289 fn reduce_parent(
290 array: &Self::Array,
291 parent: &ArrayRef,
292 child_idx: usize,
293 ) -> VortexResult<Option<ArrayRef>> {
294 crate::rules::RULES.evaluate(array, parent, child_idx)
295 }
296}
297
/// Unit struct that carries the vtable implementations for the zstd encoding.
#[derive(Clone, Debug)]
pub struct Zstd;
300
impl Zstd {
    /// Identifier under which this encoding is registered.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.zstd");
}
304
/// An array whose valid values are stored as one or more zstd-compressed frames.
///
/// Slicing does not touch the compressed data: it only narrows the
/// `slice_start..slice_stop` window over the unsliced rows.
#[derive(Clone, Debug)]
pub struct ZstdArray {
    /// Optional zstd dictionary shared by all frames.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// Compressed frames, in value order.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Dictionary size plus per-frame metadata (uncompressed size, value count).
    pub(crate) metadata: ZstdMetadata,
    /// Logical data type of the decompressed values.
    dtype: DType,
    /// Validity of all `unsliced_n_rows` rows, independent of the slice window.
    pub(crate) unsliced_validity: Validity,
    /// Row count of the array before any slicing.
    unsliced_n_rows: usize,
    /// Statistics for this array; reset whenever the array is sliced.
    stats_set: ArrayStats,
    /// First row (inclusive) of the slice window within the unsliced rows.
    slice_start: usize,
    /// End row (exclusive) of the slice window within the unsliced rows.
    slice_stop: usize,
}
317
/// The owned components of a [`ZstdArray`], as produced by [`ZstdArray::into_parts`].
#[derive(Debug)]
pub struct ZstdArrayParts {
    /// Optional zstd dictionary shared by all frames.
    pub dictionary: Option<ByteBuffer>,
    /// Compressed frames, in value order.
    pub frames: Vec<ByteBuffer>,
    /// Dictionary size plus per-frame metadata.
    pub metadata: ZstdMetadata,
    /// Logical data type of the decompressed values.
    pub dtype: DType,
    /// Validity of the unsliced rows.
    pub validity: Validity,
    /// Row count of the array before any slicing.
    pub n_rows: usize,
    /// First row (inclusive) of the slice window.
    pub slice_start: usize,
    /// End row (exclusive) of the slice window.
    pub slice_stop: usize,
}
338
/// Result of compressing a value buffer: the optional trained dictionary, the
/// compressed frames, and one metadata entry per frame.
struct Frames {
    /// Trained dictionary, if one was used for compression.
    dictionary: Option<ByteBuffer>,
    /// Compressed frames, in input order.
    frames: Vec<ByteBuffer>,
    /// Metadata for each entry of `frames`, in the same order.
    frame_metas: Vec<ZstdFrameMetadata>,
}
344
/// Picks the maximum dictionary size for a payload of `uncompressed_size` bytes:
/// one percent of the payload, but never below 256 bytes or above 100 KiB.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const MIN_DICT_SIZE: usize = 256;
    const MAX_DICT_SIZE: usize = 100 * 1024;

    let one_percent = uncompressed_size / 100;
    one_percent.max(MIN_DICT_SIZE).min(MAX_DICT_SIZE)
}
352
353fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
354 let mask = parray.validity_mask()?;
355 Ok(parray.clone().into_array().filter(mask)?.to_primitive())
356}
357
358fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
359 let mask = vbv.validity_mask()?;
360 let buffer_and_value_byte_indices = match mask.bit_buffer() {
361 AllOr::None => (Buffer::empty(), Vec::new()),
362 _ => {
363 let mut buffer = BufferMut::with_capacity(
364 usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
365 + mask.true_count() * size_of::<ViewLen>(),
366 );
367 let mut value_byte_indices = Vec::new();
368 vbv.with_iterator(|iterator| {
369 for value in iterator.flatten() {
371 value_byte_indices.push(buffer.len());
372 buffer
374 .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
375 buffer.extend_from_slice(value);
376 }
377 Ok::<_, VortexError>(())
378 })?;
379 (buffer.freeze(), value_byte_indices)
380 }
381 };
382 Ok(buffer_and_value_byte_indices)
383}
384
385pub fn reconstruct_views(
394 buffer: &ByteBuffer,
395 max_buffer_len: usize,
396) -> (Vec<ByteBuffer>, Buffer<BinaryView>) {
397 let mut views = BufferMut::<BinaryView>::empty();
398 let mut buffers = Vec::new();
399 let mut segment_start: usize = 0;
400 let mut offset = 0;
401
402 while offset < buffer.len() {
403 let str_len = ViewLen::from_le_bytes(
404 buffer
405 .get(offset..offset + size_of::<ViewLen>())
406 .vortex_expect("corrupted zstd length")
407 .try_into()
408 .ok()
409 .vortex_expect("must fit ViewLen size"),
410 ) as usize;
411
412 let value_data_offset = offset + size_of::<ViewLen>();
413 let local_offset = value_data_offset - segment_start;
414
415 if local_offset + str_len > max_buffer_len && offset > segment_start {
416 buffers.push(buffer.slice(segment_start..offset));
417 segment_start = offset;
418 }
419
420 let local_offset = u32::try_from(value_data_offset - segment_start)
421 .vortex_expect("local offset within segment must fit in u32");
422 let buf_index = u32::try_from(buffers.len()).vortex_expect("buffer index must fit in u32");
423 let value = &buffer[value_data_offset..value_data_offset + str_len];
424 views.push(BinaryView::make_view(value, buf_index, local_offset));
425 offset = value_data_offset + str_len;
426 }
427
428 if segment_start < buffer.len() {
429 buffers.push(buffer.slice(segment_start..buffer.len()));
430 }
431
432 (buffers, views.freeze())
433}
434
435impl ZstdArray {
436 pub fn new(
437 dictionary: Option<ByteBuffer>,
438 frames: Vec<ByteBuffer>,
439 dtype: DType,
440 metadata: ZstdMetadata,
441 n_rows: usize,
442 validity: Validity,
443 ) -> Self {
444 Self {
445 dictionary,
446 frames,
447 metadata,
448 dtype,
449 unsliced_validity: validity,
450 unsliced_n_rows: n_rows,
451 stats_set: Default::default(),
452 slice_start: 0,
453 slice_stop: n_rows,
454 }
455 }
456
457 fn compress_values(
458 value_bytes: &ByteBuffer,
459 frame_byte_starts: &[usize],
460 level: i32,
461 values_per_frame: usize,
462 n_values: usize,
463 use_dictionary: bool,
464 ) -> VortexResult<Frames> {
465 let n_frames = frame_byte_starts.len();
466
467 let mut sample_sizes = Vec::with_capacity(n_frames);
469 for i in 0..n_frames {
470 let frame_byte_end = frame_byte_starts
471 .get(i + 1)
472 .copied()
473 .unwrap_or(value_bytes.len());
474 sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
475 }
476 debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
477
478 let (dictionary, mut compressor) = if !use_dictionary
479 || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY
480 {
481 (None, zstd::bulk::Compressor::new(level)?)
483 } else {
484 let max_dict_size = choose_max_dict_size(value_bytes.len());
486 let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
487 .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
488
489 let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
490 (Some(ByteBuffer::from(dict)), compressor)
491 };
492
493 let mut frame_metas = vec![];
494 let mut frames = vec![];
495 for i in 0..n_frames {
496 let frame_byte_end = frame_byte_starts
497 .get(i + 1)
498 .copied()
499 .unwrap_or(value_bytes.len());
500
501 let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
502 let compressed = compressor
503 .compress(uncompressed)
504 .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
505 frame_metas.push(ZstdFrameMetadata {
506 uncompressed_size: uncompressed.len() as u64,
507 n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
508 });
509 frames.push(ByteBuffer::from(compressed));
510 }
511
512 Ok(Frames {
513 dictionary,
514 frames,
515 frame_metas,
516 })
517 }
518
519 pub fn from_primitive(
526 parray: &PrimitiveArray,
527 level: i32,
528 values_per_frame: usize,
529 ) -> VortexResult<Self> {
530 Self::from_primitive_impl(parray, level, values_per_frame, true)
531 }
532
533 pub fn from_primitive_without_dict(
547 parray: &PrimitiveArray,
548 level: i32,
549 values_per_frame: usize,
550 ) -> VortexResult<Self> {
551 Self::from_primitive_impl(parray, level, values_per_frame, false)
552 }
553
554 fn from_primitive_impl(
555 parray: &PrimitiveArray,
556 level: i32,
557 values_per_frame: usize,
558 use_dictionary: bool,
559 ) -> VortexResult<Self> {
560 let dtype = parray.dtype().clone();
561 let byte_width = parray.ptype().byte_width();
562
563 let values = collect_valid_primitive(parray)?;
565 let n_values = values.len();
566 let values_per_frame = if values_per_frame > 0 {
567 values_per_frame
568 } else {
569 n_values
570 };
571
572 let value_bytes = values.buffer_handle().try_to_host_sync()?;
573 let alignment = *value_bytes.alignment();
575 let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment;
576
577 let frame_byte_starts = (0..n_values * byte_width)
578 .step_by(step_width)
579 .collect::<Vec<_>>();
580 let Frames {
581 dictionary,
582 frames,
583 frame_metas,
584 } = Self::compress_values(
585 &value_bytes,
586 &frame_byte_starts,
587 level,
588 values_per_frame,
589 n_values,
590 use_dictionary,
591 )?;
592
593 let metadata = ZstdMetadata {
594 dictionary_size: dictionary
595 .as_ref()
596 .map_or(0, |dict| dict.len())
597 .try_into()?,
598 frames: frame_metas,
599 };
600
601 Ok(ZstdArray::new(
602 dictionary,
603 frames,
604 dtype,
605 metadata,
606 parray.len(),
607 parray.validity().clone(),
608 ))
609 }
610
611 pub fn from_var_bin_view(
618 vbv: &VarBinViewArray,
619 level: i32,
620 values_per_frame: usize,
621 ) -> VortexResult<Self> {
622 Self::from_var_bin_view_impl(vbv, level, values_per_frame, true)
623 }
624
625 pub fn from_var_bin_view_without_dict(
639 vbv: &VarBinViewArray,
640 level: i32,
641 values_per_frame: usize,
642 ) -> VortexResult<Self> {
643 Self::from_var_bin_view_impl(vbv, level, values_per_frame, false)
644 }
645
646 fn from_var_bin_view_impl(
647 vbv: &VarBinViewArray,
648 level: i32,
649 values_per_frame: usize,
650 use_dictionary: bool,
651 ) -> VortexResult<Self> {
652 let dtype = vbv.dtype().clone();
658
659 let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
661 let n_values = value_byte_indices.len();
662 let values_per_frame = if values_per_frame > 0 {
663 values_per_frame
664 } else {
665 n_values
666 };
667
668 let frame_byte_starts = (0..n_values)
669 .step_by(values_per_frame)
670 .map(|i| value_byte_indices[i])
671 .collect::<Vec<_>>();
672 let Frames {
673 dictionary,
674 frames,
675 frame_metas,
676 } = Self::compress_values(
677 &value_bytes,
678 &frame_byte_starts,
679 level,
680 values_per_frame,
681 n_values,
682 use_dictionary,
683 )?;
684
685 let metadata = ZstdMetadata {
686 dictionary_size: dictionary
687 .as_ref()
688 .map_or(0, |dict| dict.len())
689 .try_into()?,
690 frames: frame_metas,
691 };
692 Ok(ZstdArray::new(
693 dictionary,
694 frames,
695 dtype,
696 metadata,
697 vbv.len(),
698 vbv.validity().clone(),
699 ))
700 }
701
702 pub fn from_canonical(
703 canonical: &Canonical,
704 level: i32,
705 values_per_frame: usize,
706 ) -> VortexResult<Option<Self>> {
707 match canonical {
708 Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
709 parray,
710 level,
711 values_per_frame,
712 )?)),
713 Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
714 vbv,
715 level,
716 values_per_frame,
717 )?)),
718 _ => Ok(None),
719 }
720 }
721
722 pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
723 Self::from_canonical(&array.to_canonical()?, level, values_per_frame)?
724 .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
725 }
726
727 fn byte_width(&self) -> usize {
728 if self.dtype.is_primitive() {
729 self.dtype.as_ptype().byte_width()
730 } else {
731 1
732 }
733 }
734
735 pub fn decompress(&self, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
736 let byte_width = self.byte_width();
739 let slice_n_rows = self.slice_stop - self.slice_start;
740 let slice_value_indices = self
741 .unsliced_validity
742 .execute_mask(self.unsliced_n_rows, ctx)?
743 .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
744
745 let slice_value_idx_start = slice_value_indices[0];
746 let slice_value_idx_stop = slice_value_indices[1];
747
748 let mut frames_to_decompress = vec![];
749 let mut value_idx_start = 0;
750 let mut uncompressed_size_to_decompress = 0;
751 let mut n_skipped_values = 0;
752 for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
753 if value_idx_start >= slice_value_idx_stop {
754 break;
755 }
756
757 let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
758 .vortex_expect("Uncompressed size must fit in usize");
759 let frame_n_values = if frame_meta.n_values == 0 {
760 frame_uncompressed_size / byte_width
762 } else {
763 usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
764 };
765
766 let value_idx_stop = value_idx_start + frame_n_values;
767 if value_idx_stop > slice_value_idx_start {
768 frames_to_decompress.push(frame);
770 uncompressed_size_to_decompress += frame_uncompressed_size;
771 } else {
772 n_skipped_values += frame_n_values;
773 }
774 value_idx_start = value_idx_stop;
775 }
776
777 let mut decompressor = if let Some(dictionary) = &self.dictionary {
779 zstd::bulk::Decompressor::with_dictionary(dictionary)?
780 } else {
781 zstd::bulk::Decompressor::new()?
782 };
783 let mut decompressed = ByteBufferMut::with_capacity_aligned(
784 uncompressed_size_to_decompress,
785 Alignment::new(byte_width),
786 );
787 unsafe {
788 decompressed.set_len(uncompressed_size_to_decompress);
791 }
792 let mut uncompressed_start = 0;
793 for frame in frames_to_decompress {
794 let uncompressed_written = decompressor
795 .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])?;
796 uncompressed_start += uncompressed_written;
797 }
798 if uncompressed_start != uncompressed_size_to_decompress {
799 vortex_panic!(
800 "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
801 uncompressed_size_to_decompress,
802 uncompressed_start
803 );
804 }
805
806 let decompressed = decompressed.freeze();
807 let mut slice_validity = self
809 .unsliced_validity
810 .slice(self.slice_start..self.slice_stop)?;
811
812 if !self.dtype().is_nullable() && !matches!(slice_validity, Validity::NonNullable) {
821 assert!(
822 matches!(slice_validity, Validity::AllValid),
823 "ZSTD array expects to be non-nullable but there are nulls after decompression"
824 );
825
826 slice_validity = Validity::NonNullable;
827 } else if self.dtype.is_nullable() && matches!(slice_validity, Validity::NonNullable) {
828 slice_validity = Validity::AllValid;
829 }
830 match &self.dtype {
835 DType::Primitive(..) => {
836 let slice_values_buffer = decompressed.slice(
837 (slice_value_idx_start - n_skipped_values) * byte_width
838 ..(slice_value_idx_stop - n_skipped_values) * byte_width,
839 );
840 let primitive = PrimitiveArray::from_values_byte_buffer(
841 slice_values_buffer,
842 self.dtype.as_ptype(),
843 slice_validity,
844 slice_n_rows,
845 );
846
847 Ok(primitive.into_array())
848 }
849 DType::Binary(_) | DType::Utf8(_) => {
850 match slice_validity.execute_mask(slice_n_rows, ctx)?.indices() {
851 AllOr::All => {
852 let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
853 let valid_views = all_views.slice(
854 slice_value_idx_start - n_skipped_values
855 ..slice_value_idx_stop - n_skipped_values,
856 );
857
858 Ok(unsafe {
860 VarBinViewArray::new_unchecked(
861 valid_views,
862 Arc::from(buffers),
863 self.dtype.clone(),
864 slice_validity,
865 )
866 }
867 .into_array())
868 }
869 AllOr::None => Ok(ConstantArray::new(
870 Scalar::null(self.dtype.clone()),
871 slice_n_rows,
872 )
873 .into_array()),
874 AllOr::Some(valid_indices) => {
875 let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
876 let valid_views = all_views.slice(
877 slice_value_idx_start - n_skipped_values
878 ..slice_value_idx_stop - n_skipped_values,
879 );
880
881 let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
882 for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
883 views[*index] = view
884 }
885
886 Ok(unsafe {
888 VarBinViewArray::new_unchecked(
889 views.freeze(),
890 Arc::from(buffers),
891 self.dtype.clone(),
892 slice_validity,
893 )
894 }
895 .into_array())
896 }
897 }
898 }
899 _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype),
900 }
901 }
902
903 pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
904 let new_start = self.slice_start + start;
905 let new_stop = self.slice_start + stop;
906
907 assert!(
908 new_start <= self.slice_stop,
909 "new slice start {new_start} exceeds end {}",
910 self.slice_stop
911 );
912
913 assert!(
914 new_stop <= self.slice_stop,
915 "new slice stop {new_stop} exceeds end {}",
916 self.slice_stop
917 );
918
919 ZstdArray {
920 slice_start: self.slice_start + start,
921 slice_stop: self.slice_start + stop,
922 stats_set: Default::default(),
923 ..self.clone()
924 }
925 }
926
927 pub fn into_parts(self) -> ZstdArrayParts {
929 ZstdArrayParts {
930 dictionary: self.dictionary,
931 frames: self.frames,
932 metadata: self.metadata,
933 dtype: self.dtype,
934 validity: self.unsliced_validity,
935 n_rows: self.unsliced_n_rows,
936 slice_start: self.slice_start,
937 slice_stop: self.slice_stop,
938 }
939 }
940
941 pub(crate) fn dtype(&self) -> &DType {
942 &self.dtype
943 }
944
945 pub(crate) fn slice_start(&self) -> usize {
946 self.slice_start
947 }
948
949 pub(crate) fn slice_stop(&self) -> usize {
950 self.slice_stop
951 }
952
953 pub(crate) fn unsliced_n_rows(&self) -> usize {
954 self.unsliced_n_rows
955 }
956}
957
958impl ValiditySliceHelper for ZstdArray {
959 fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
960 (&self.unsliced_validity, self.slice_start, self.slice_stop)
961 }
962}
963
964impl OperationsVTable<Zstd> for Zstd {
965 fn scalar_at(array: &ZstdArray, index: usize) -> VortexResult<Scalar> {
966 let mut ctx = LEGACY_SESSION.create_execution_ctx();
967 array
968 ._slice(index, index + 1)
969 .decompress(&mut ctx)?
970 .scalar_at(0)
971 }
972}
973
#[cfg(test)]
#[allow(clippy::cast_possible_truncation)]
mod tests {
    use vortex_buffer::ByteBuffer;

    use super::reconstruct_views;
    use crate::array::BinaryView;

    /// Builds the length-prefixed layout `reconstruct_views` consumes: each string is
    /// preceded by its length as a little-endian `u32`.
    fn make_interleaved(strings: &[&[u8]]) -> ByteBuffer {
        let bytes: Vec<u8> = strings
            .iter()
            .flat_map(|s| {
                let prefix = (s.len() as u32).to_le_bytes();
                prefix.into_iter().chain(s.iter().copied())
            })
            .collect();
        ByteBuffer::copy_from(bytes.as_slice())
    }

    #[test]
    fn test_reconstruct_views_no_split() {
        let strings: &[&[u8]] = &[b"hello", b"world"];
        let interleaved = make_interleaved(strings);

        let (buffers, views) = reconstruct_views(&interleaved, 1024);

        // Everything fits in one segment.
        assert_eq!(buffers.len(), 1);
        assert_eq!(views.len(), 2);
        // "hello" follows its 4-byte prefix.
        assert_eq!(views[0], BinaryView::make_view(b"hello", 0, 4));
        // "world" follows 4 + 5 bytes of the first record and its own 4-byte prefix.
        assert_eq!(views[1], BinaryView::make_view(b"world", 0, 13));
    }

    #[test]
    fn test_reconstruct_views_split_across_segments() {
        // Each record is 4 + 13 = 17 bytes, so with a 20-byte limit the second value
        // cannot share a segment with the first.
        let strings: &[&[u8]] = &[b"aaaaaaaaaaaaa", b"bbbbbbbbbbbbb"];
        let interleaved = make_interleaved(strings);

        let (buffers, views) = reconstruct_views(&interleaved, 20);

        assert_eq!(buffers.len(), 2);
        assert_eq!(views.len(), 2);
        assert_eq!(views[0], BinaryView::make_view(b"aaaaaaaaaaaaa", 0, 4));
        // The second segment starts at the second record's prefix, so its value sits
        // at offset 4 again.
        assert_eq!(views[1], BinaryView::make_view(b"bbbbbbbbbbbbb", 1, 4));
    }
}