1use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use itertools::Itertools as _;
9use prost::Message as _;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayRef;
13use vortex_array::Canonical;
14use vortex_array::DynArray;
15use vortex_array::ExecutionCtx;
16use vortex_array::ExecutionResult;
17use vortex_array::IntoArray;
18use vortex_array::LEGACY_SESSION;
19use vortex_array::Precision;
20use vortex_array::ProstMetadata;
21use vortex_array::ToCanonical;
22use vortex_array::VortexSessionExecute;
23use vortex_array::accessor::ArrayAccessor;
24use vortex_array::arrays::ConstantArray;
25use vortex_array::arrays::PrimitiveArray;
26use vortex_array::arrays::VarBinViewArray;
27use vortex_array::arrays::varbinview::build_views::BinaryView;
28use vortex_array::arrays::varbinview::build_views::MAX_BUFFER_LEN;
29use vortex_array::buffer::BufferHandle;
30use vortex_array::dtype::DType;
31use vortex_array::scalar::Scalar;
32use vortex_array::serde::ArrayChildren;
33use vortex_array::stats::ArrayStats;
34use vortex_array::stats::StatsSetRef;
35use vortex_array::validity::Validity;
36use vortex_array::vtable;
37use vortex_array::vtable::Array;
38use vortex_array::vtable::ArrayId;
39use vortex_array::vtable::OperationsVTable;
40use vortex_array::vtable::VTable;
41use vortex_array::vtable::ValidityHelper;
42use vortex_array::vtable::ValiditySliceHelper;
43use vortex_array::vtable::ValidityVTableFromValiditySliceHelper;
44use vortex_array::vtable::validity_nchildren;
45use vortex_array::vtable::validity_to_child;
46use vortex_buffer::Alignment;
47use vortex_buffer::Buffer;
48use vortex_buffer::BufferMut;
49use vortex_buffer::ByteBuffer;
50use vortex_buffer::ByteBufferMut;
51use vortex_error::VortexError;
52use vortex_error::VortexExpect;
53use vortex_error::VortexResult;
54use vortex_error::vortex_bail;
55use vortex_error::vortex_ensure;
56use vortex_error::vortex_err;
57use vortex_error::vortex_panic;
58use vortex_mask::AllOr;
59use vortex_session::VortexSession;
60
61use crate::ZstdFrameMetadata;
62use crate::ZstdMetadata;
63
/// Minimum number of samples (one sample per frame) required before
/// `ZstdArray::compress_values` trains a dictionary; with fewer samples the
/// frames are compressed without one.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
/// Integer type of the little-endian length prefix written in front of every
/// variable-length value in the interleaved byte buffer.
type ViewLen = u32;

// Invokes the `vortex_array::vtable!` macro for the `Zstd` encoding.
vtable!(Zstd);
87
88impl VTable for Zstd {
89 type Array = ZstdArray;
90
91 type Metadata = ProstMetadata<ZstdMetadata>;
92 type OperationsVTable = Self;
93 type ValidityVTable = ValidityVTableFromValiditySliceHelper;
94
95 fn vtable(_array: &Self::Array) -> &Self {
96 &Zstd
97 }
98
99 fn id(&self) -> ArrayId {
100 Self::ID
101 }
102
103 fn len(array: &ZstdArray) -> usize {
104 array.slice_stop - array.slice_start
105 }
106
107 fn dtype(array: &ZstdArray) -> &DType {
108 &array.dtype
109 }
110
111 fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
112 array.stats_set.to_ref(array.as_ref())
113 }
114
115 fn array_hash<H: std::hash::Hasher>(array: &ZstdArray, state: &mut H, precision: Precision) {
116 match &array.dictionary {
117 Some(dict) => {
118 true.hash(state);
119 dict.array_hash(state, precision);
120 }
121 None => {
122 false.hash(state);
123 }
124 }
125 for frame in &array.frames {
126 frame.array_hash(state, precision);
127 }
128 array.dtype.hash(state);
129 array.unsliced_validity.array_hash(state, precision);
130 array.unsliced_n_rows.hash(state);
131 array.slice_start.hash(state);
132 array.slice_stop.hash(state);
133 }
134
135 fn array_eq(array: &ZstdArray, other: &ZstdArray, precision: Precision) -> bool {
136 if !match (&array.dictionary, &other.dictionary) {
137 (Some(d1), Some(d2)) => d1.array_eq(d2, precision),
138 (None, None) => true,
139 _ => false,
140 } {
141 return false;
142 }
143 if array.frames.len() != other.frames.len() {
144 return false;
145 }
146 for (a, b) in array.frames.iter().zip(&other.frames) {
147 if !a.array_eq(b, precision) {
148 return false;
149 }
150 }
151 array.dtype == other.dtype
152 && array
153 .unsliced_validity
154 .array_eq(&other.unsliced_validity, precision)
155 && array.unsliced_n_rows == other.unsliced_n_rows
156 && array.slice_start == other.slice_start
157 && array.slice_stop == other.slice_stop
158 }
159
160 fn nbuffers(array: &ZstdArray) -> usize {
161 array.dictionary.is_some() as usize + array.frames.len()
162 }
163
164 fn buffer(array: &ZstdArray, idx: usize) -> BufferHandle {
165 if let Some(dict) = &array.dictionary {
166 if idx == 0 {
167 return BufferHandle::new_host(dict.clone());
168 }
169 BufferHandle::new_host(array.frames[idx - 1].clone())
170 } else {
171 BufferHandle::new_host(array.frames[idx].clone())
172 }
173 }
174
175 fn buffer_name(array: &ZstdArray, idx: usize) -> Option<String> {
176 if array.dictionary.is_some() {
177 if idx == 0 {
178 Some("dictionary".to_string())
179 } else {
180 Some(format!("frame_{}", idx - 1))
181 }
182 } else {
183 Some(format!("frame_{idx}"))
184 }
185 }
186
187 fn nchildren(array: &ZstdArray) -> usize {
188 validity_nchildren(&array.unsliced_validity)
189 }
190
191 fn child(array: &ZstdArray, idx: usize) -> ArrayRef {
192 validity_to_child(&array.unsliced_validity, array.unsliced_n_rows)
193 .unwrap_or_else(|| vortex_panic!("ZstdArray child index {idx} out of bounds"))
194 }
195
196 fn child_name(_array: &ZstdArray, idx: usize) -> String {
197 match idx {
198 0 => "validity".to_string(),
199 _ => vortex_panic!("ZstdArray child_name index {idx} out of bounds"),
200 }
201 }
202
203 fn metadata(array: &ZstdArray) -> VortexResult<Self::Metadata> {
204 Ok(ProstMetadata(array.metadata.clone()))
205 }
206
207 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
208 Ok(Some(metadata.0.encode_to_vec()))
209 }
210
211 fn deserialize(
212 bytes: &[u8],
213 _dtype: &DType,
214 _len: usize,
215 _buffers: &[BufferHandle],
216 _session: &VortexSession,
217 ) -> VortexResult<Self::Metadata> {
218 Ok(ProstMetadata(ZstdMetadata::decode(bytes)?))
219 }
220
221 fn build(
222 dtype: &DType,
223 len: usize,
224 metadata: &Self::Metadata,
225 buffers: &[BufferHandle],
226 children: &dyn ArrayChildren,
227 ) -> VortexResult<ZstdArray> {
228 let validity = if children.is_empty() {
229 Validity::from(dtype.nullability())
230 } else if children.len() == 1 {
231 let validity = children.get(0, &Validity::DTYPE, len)?;
232 Validity::Array(validity)
233 } else {
234 vortex_bail!("ZstdArray expected 0 or 1 child, got {}", children.len());
235 };
236
237 let (dictionary_buffer, compressed_buffers) = if metadata.0.dictionary_size == 0 {
238 (
240 None,
241 buffers
242 .iter()
243 .map(|b| b.clone().try_to_host_sync())
244 .collect::<VortexResult<Vec<_>>>()?,
245 )
246 } else {
247 (
249 Some(buffers[0].clone().try_to_host_sync()?),
250 buffers[1..]
251 .iter()
252 .map(|b| b.clone().try_to_host_sync())
253 .collect::<VortexResult<Vec<_>>>()?,
254 )
255 };
256
257 Ok(ZstdArray::new(
258 dictionary_buffer,
259 compressed_buffers,
260 dtype.clone(),
261 metadata.0.clone(),
262 len,
263 validity,
264 ))
265 }
266
267 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
268 vortex_ensure!(
269 children.len() <= 1,
270 "ZstdArray expects at most 1 child (validity), got {}",
271 children.len()
272 );
273
274 array.unsliced_validity = if children.is_empty() {
275 Validity::from(array.dtype.nullability())
276 } else {
277 Validity::Array(children.into_iter().next().vortex_expect("checked"))
278 };
279
280 Ok(())
281 }
282
283 fn execute(array: Arc<Array<Self>>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
284 array
285 .decompress(ctx)?
286 .execute::<ArrayRef>(ctx)
287 .map(ExecutionResult::done)
288 }
289
290 fn reduce_parent(
291 array: &Array<Self>,
292 parent: &ArrayRef,
293 child_idx: usize,
294 ) -> VortexResult<Option<ArrayRef>> {
295 crate::rules::RULES.evaluate(array, parent, child_idx)
296 }
297}
298
/// Zero-sized marker type for the Zstd encoding; the `VTable` and
/// `OperationsVTable` traits are implemented on it.
#[derive(Clone, Debug)]
pub struct Zstd;
301
impl Zstd {
    /// Array identifier of this encoding (`"vortex.zstd"`), returned by `VTable::id`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.zstd");
}
305
/// An array whose valid values are stored as one or more Zstd-compressed frames.
///
/// Slicing does not touch the compressed data: the array keeps the unsliced
/// validity and row count and records the visible window in
/// `slice_start..slice_stop`.
#[derive(Clone, Debug)]
pub struct ZstdArray {
    /// Optional dictionary used to compress and decompress every frame.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// Compressed frames, in value order.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Serializable metadata: the dictionary size and one entry per frame.
    pub(crate) metadata: ZstdMetadata,
    /// Logical dtype of the decompressed array.
    dtype: DType,
    /// Validity covering all `unsliced_n_rows` rows, ignoring the slice window.
    pub(crate) unsliced_validity: Validity,
    /// Number of rows (valid or not) before any slicing.
    unsliced_n_rows: usize,
    /// Statistics for this array; reset to the default whenever the array is sliced.
    stats_set: ArrayStats,
    /// First row (inclusive) of the visible window, in unsliced row coordinates.
    slice_start: usize,
    /// Last row (exclusive) of the visible window, in unsliced row coordinates.
    slice_stop: usize,
}
318
/// The owned components of a [`ZstdArray`], as returned by [`ZstdArray::into_parts`].
#[derive(Debug)]
pub struct ZstdArrayParts {
    /// Optional dictionary shared by all frames.
    pub dictionary: Option<ByteBuffer>,
    /// Compressed frames, in value order.
    pub frames: Vec<ByteBuffer>,
    /// Dictionary size and per-frame metadata.
    pub metadata: ZstdMetadata,
    /// Logical dtype of the decompressed array.
    pub dtype: DType,
    /// Validity of the unsliced array.
    pub validity: Validity,
    /// Number of rows of the unsliced array.
    pub n_rows: usize,
    /// Start (inclusive) of the slice window in unsliced row coordinates.
    pub slice_start: usize,
    /// Stop (exclusive) of the slice window in unsliced row coordinates.
    pub slice_stop: usize,
}
339
/// Result of `ZstdArray::compress_values`: the trained dictionary (if any), the
/// compressed frames and one metadata entry per frame.
struct Frames {
    /// Dictionary the frames were compressed with, or `None` when none was trained.
    dictionary: Option<ByteBuffer>,
    /// Compressed frame payloads.
    frames: Vec<ByteBuffer>,
    /// Uncompressed size and value count of each frame, parallel to `frames`.
    frame_metas: Vec<ZstdFrameMetadata>,
}
345
/// Picks the maximum dictionary size for training: one hundredth of the
/// uncompressed input, kept within 256 bytes and 100 KiB.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const SMALLEST_DICT: usize = 256;
    const LARGEST_DICT: usize = 100 * 1024;

    let one_percent = uncompressed_size / 100;
    one_percent.max(SMALLEST_DICT).min(LARGEST_DICT)
}
353
354fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
355 let mask = parray.validity_mask()?;
356 Ok(parray.clone().into_array().filter(mask)?.to_primitive())
357}
358
359fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
360 let mask = vbv.validity_mask()?;
361 let buffer_and_value_byte_indices = match mask.bit_buffer() {
362 AllOr::None => (Buffer::empty(), Vec::new()),
363 _ => {
364 let mut buffer = BufferMut::with_capacity(
365 usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
366 + mask.true_count() * size_of::<ViewLen>(),
367 );
368 let mut value_byte_indices = Vec::new();
369 vbv.with_iterator(|iterator| {
370 for value in iterator.flatten() {
372 value_byte_indices.push(buffer.len());
373 buffer
375 .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
376 buffer.extend_from_slice(value);
377 }
378 Ok::<_, VortexError>(())
379 })?;
380 (buffer.freeze(), value_byte_indices)
381 }
382 };
383 Ok(buffer_and_value_byte_indices)
384}
385
386pub fn reconstruct_views(
395 buffer: &ByteBuffer,
396 max_buffer_len: usize,
397) -> (Vec<ByteBuffer>, Buffer<BinaryView>) {
398 let mut views = BufferMut::<BinaryView>::empty();
399 let mut buffers = Vec::new();
400 let mut segment_start: usize = 0;
401 let mut offset = 0;
402
403 while offset < buffer.len() {
404 let str_len = ViewLen::from_le_bytes(
405 buffer
406 .get(offset..offset + size_of::<ViewLen>())
407 .vortex_expect("corrupted zstd length")
408 .try_into()
409 .ok()
410 .vortex_expect("must fit ViewLen size"),
411 ) as usize;
412
413 let value_data_offset = offset + size_of::<ViewLen>();
414 let local_offset = value_data_offset - segment_start;
415
416 if local_offset + str_len > max_buffer_len && offset > segment_start {
417 buffers.push(buffer.slice(segment_start..offset));
418 segment_start = offset;
419 }
420
421 let local_offset = u32::try_from(value_data_offset - segment_start)
422 .vortex_expect("local offset within segment must fit in u32");
423 let buf_index = u32::try_from(buffers.len()).vortex_expect("buffer index must fit in u32");
424 let value = &buffer[value_data_offset..value_data_offset + str_len];
425 views.push(BinaryView::make_view(value, buf_index, local_offset));
426 offset = value_data_offset + str_len;
427 }
428
429 if segment_start < buffer.len() {
430 buffers.push(buffer.slice(segment_start..buffer.len()));
431 }
432
433 (buffers, views.freeze())
434}
435
436impl ZstdArray {
437 pub fn new(
438 dictionary: Option<ByteBuffer>,
439 frames: Vec<ByteBuffer>,
440 dtype: DType,
441 metadata: ZstdMetadata,
442 n_rows: usize,
443 validity: Validity,
444 ) -> Self {
445 Self {
446 dictionary,
447 frames,
448 metadata,
449 dtype,
450 unsliced_validity: validity,
451 unsliced_n_rows: n_rows,
452 stats_set: Default::default(),
453 slice_start: 0,
454 slice_stop: n_rows,
455 }
456 }
457
458 fn compress_values(
459 value_bytes: &ByteBuffer,
460 frame_byte_starts: &[usize],
461 level: i32,
462 values_per_frame: usize,
463 n_values: usize,
464 use_dictionary: bool,
465 ) -> VortexResult<Frames> {
466 let n_frames = frame_byte_starts.len();
467
468 let mut sample_sizes = Vec::with_capacity(n_frames);
470 for i in 0..n_frames {
471 let frame_byte_end = frame_byte_starts
472 .get(i + 1)
473 .copied()
474 .unwrap_or(value_bytes.len());
475 sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
476 }
477 debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
478
479 let (dictionary, mut compressor) = if !use_dictionary
480 || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY
481 {
482 (None, zstd::bulk::Compressor::new(level)?)
484 } else {
485 let max_dict_size = choose_max_dict_size(value_bytes.len());
487 let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
488 .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
489
490 let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
491 (Some(ByteBuffer::from(dict)), compressor)
492 };
493
494 let mut frame_metas = vec![];
495 let mut frames = vec![];
496 for i in 0..n_frames {
497 let frame_byte_end = frame_byte_starts
498 .get(i + 1)
499 .copied()
500 .unwrap_or(value_bytes.len());
501
502 let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
503 let compressed = compressor
504 .compress(uncompressed)
505 .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
506 frame_metas.push(ZstdFrameMetadata {
507 uncompressed_size: uncompressed.len() as u64,
508 n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
509 });
510 frames.push(ByteBuffer::from(compressed));
511 }
512
513 Ok(Frames {
514 dictionary,
515 frames,
516 frame_metas,
517 })
518 }
519
520 pub fn from_primitive(
527 parray: &PrimitiveArray,
528 level: i32,
529 values_per_frame: usize,
530 ) -> VortexResult<Self> {
531 Self::from_primitive_impl(parray, level, values_per_frame, true)
532 }
533
534 pub fn from_primitive_without_dict(
548 parray: &PrimitiveArray,
549 level: i32,
550 values_per_frame: usize,
551 ) -> VortexResult<Self> {
552 Self::from_primitive_impl(parray, level, values_per_frame, false)
553 }
554
555 fn from_primitive_impl(
556 parray: &PrimitiveArray,
557 level: i32,
558 values_per_frame: usize,
559 use_dictionary: bool,
560 ) -> VortexResult<Self> {
561 let dtype = parray.dtype().clone();
562 let byte_width = parray.ptype().byte_width();
563
564 let values = collect_valid_primitive(parray)?;
566 let n_values = values.len();
567 let values_per_frame = if values_per_frame > 0 {
568 values_per_frame
569 } else {
570 n_values
571 };
572
573 let value_bytes = values.buffer_handle().try_to_host_sync()?;
574 let alignment = *value_bytes.alignment();
576 let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment;
577
578 let frame_byte_starts = (0..n_values * byte_width)
579 .step_by(step_width)
580 .collect::<Vec<_>>();
581 let Frames {
582 dictionary,
583 frames,
584 frame_metas,
585 } = Self::compress_values(
586 &value_bytes,
587 &frame_byte_starts,
588 level,
589 values_per_frame,
590 n_values,
591 use_dictionary,
592 )?;
593
594 let metadata = ZstdMetadata {
595 dictionary_size: dictionary
596 .as_ref()
597 .map_or(0, |dict| dict.len())
598 .try_into()?,
599 frames: frame_metas,
600 };
601
602 Ok(ZstdArray::new(
603 dictionary,
604 frames,
605 dtype,
606 metadata,
607 parray.len(),
608 parray.validity().clone(),
609 ))
610 }
611
612 pub fn from_var_bin_view(
619 vbv: &VarBinViewArray,
620 level: i32,
621 values_per_frame: usize,
622 ) -> VortexResult<Self> {
623 Self::from_var_bin_view_impl(vbv, level, values_per_frame, true)
624 }
625
626 pub fn from_var_bin_view_without_dict(
640 vbv: &VarBinViewArray,
641 level: i32,
642 values_per_frame: usize,
643 ) -> VortexResult<Self> {
644 Self::from_var_bin_view_impl(vbv, level, values_per_frame, false)
645 }
646
647 fn from_var_bin_view_impl(
648 vbv: &VarBinViewArray,
649 level: i32,
650 values_per_frame: usize,
651 use_dictionary: bool,
652 ) -> VortexResult<Self> {
653 let dtype = vbv.dtype().clone();
659
660 let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
662 let n_values = value_byte_indices.len();
663 let values_per_frame = if values_per_frame > 0 {
664 values_per_frame
665 } else {
666 n_values
667 };
668
669 let frame_byte_starts = (0..n_values)
670 .step_by(values_per_frame)
671 .map(|i| value_byte_indices[i])
672 .collect::<Vec<_>>();
673 let Frames {
674 dictionary,
675 frames,
676 frame_metas,
677 } = Self::compress_values(
678 &value_bytes,
679 &frame_byte_starts,
680 level,
681 values_per_frame,
682 n_values,
683 use_dictionary,
684 )?;
685
686 let metadata = ZstdMetadata {
687 dictionary_size: dictionary
688 .as_ref()
689 .map_or(0, |dict| dict.len())
690 .try_into()?,
691 frames: frame_metas,
692 };
693 Ok(ZstdArray::new(
694 dictionary,
695 frames,
696 dtype,
697 metadata,
698 vbv.len(),
699 vbv.validity().clone(),
700 ))
701 }
702
703 pub fn from_canonical(
704 canonical: &Canonical,
705 level: i32,
706 values_per_frame: usize,
707 ) -> VortexResult<Option<Self>> {
708 match canonical {
709 Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
710 parray,
711 level,
712 values_per_frame,
713 )?)),
714 Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
715 vbv,
716 level,
717 values_per_frame,
718 )?)),
719 _ => Ok(None),
720 }
721 }
722
723 pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
724 Self::from_canonical(&array.to_canonical()?, level, values_per_frame)?
725 .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
726 }
727
728 fn byte_width(&self) -> usize {
729 if self.dtype.is_primitive() {
730 self.dtype.as_ptype().byte_width()
731 } else {
732 1
733 }
734 }
735
736 pub fn decompress(&self, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
737 let byte_width = self.byte_width();
740 let slice_n_rows = self.slice_stop - self.slice_start;
741 let slice_value_indices = self
742 .unsliced_validity
743 .execute_mask(self.unsliced_n_rows, ctx)?
744 .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
745
746 let slice_value_idx_start = slice_value_indices[0];
747 let slice_value_idx_stop = slice_value_indices[1];
748
749 let mut frames_to_decompress = vec![];
750 let mut value_idx_start = 0;
751 let mut uncompressed_size_to_decompress = 0;
752 let mut n_skipped_values = 0;
753 for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
754 if value_idx_start >= slice_value_idx_stop {
755 break;
756 }
757
758 let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
759 .vortex_expect("Uncompressed size must fit in usize");
760 let frame_n_values = if frame_meta.n_values == 0 {
761 frame_uncompressed_size / byte_width
763 } else {
764 usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
765 };
766
767 let value_idx_stop = value_idx_start + frame_n_values;
768 if value_idx_stop > slice_value_idx_start {
769 frames_to_decompress.push(frame);
771 uncompressed_size_to_decompress += frame_uncompressed_size;
772 } else {
773 n_skipped_values += frame_n_values;
774 }
775 value_idx_start = value_idx_stop;
776 }
777
778 let mut decompressor = if let Some(dictionary) = &self.dictionary {
780 zstd::bulk::Decompressor::with_dictionary(dictionary)?
781 } else {
782 zstd::bulk::Decompressor::new()?
783 };
784 let mut decompressed = ByteBufferMut::with_capacity_aligned(
785 uncompressed_size_to_decompress,
786 Alignment::new(byte_width),
787 );
788 unsafe {
789 decompressed.set_len(uncompressed_size_to_decompress);
792 }
793 let mut uncompressed_start = 0;
794 for frame in frames_to_decompress {
795 let uncompressed_written = decompressor
796 .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])?;
797 uncompressed_start += uncompressed_written;
798 }
799 if uncompressed_start != uncompressed_size_to_decompress {
800 vortex_panic!(
801 "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
802 uncompressed_size_to_decompress,
803 uncompressed_start
804 );
805 }
806
807 let decompressed = decompressed.freeze();
808 let mut slice_validity = self
810 .unsliced_validity
811 .slice(self.slice_start..self.slice_stop)?;
812
813 if !self.dtype().is_nullable() && !matches!(slice_validity, Validity::NonNullable) {
822 assert!(
823 matches!(slice_validity, Validity::AllValid),
824 "ZSTD array expects to be non-nullable but there are nulls after decompression"
825 );
826
827 slice_validity = Validity::NonNullable;
828 } else if self.dtype.is_nullable() && matches!(slice_validity, Validity::NonNullable) {
829 slice_validity = Validity::AllValid;
830 }
831 match &self.dtype {
836 DType::Primitive(..) => {
837 let slice_values_buffer = decompressed.slice(
838 (slice_value_idx_start - n_skipped_values) * byte_width
839 ..(slice_value_idx_stop - n_skipped_values) * byte_width,
840 );
841 let primitive = PrimitiveArray::from_values_byte_buffer(
842 slice_values_buffer,
843 self.dtype.as_ptype(),
844 slice_validity,
845 slice_n_rows,
846 );
847
848 Ok(primitive.into_array())
849 }
850 DType::Binary(_) | DType::Utf8(_) => {
851 match slice_validity.execute_mask(slice_n_rows, ctx)?.indices() {
852 AllOr::All => {
853 let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
854 let valid_views = all_views.slice(
855 slice_value_idx_start - n_skipped_values
856 ..slice_value_idx_stop - n_skipped_values,
857 );
858
859 Ok(unsafe {
861 VarBinViewArray::new_unchecked(
862 valid_views,
863 Arc::from(buffers),
864 self.dtype.clone(),
865 slice_validity,
866 )
867 }
868 .into_array())
869 }
870 AllOr::None => Ok(ConstantArray::new(
871 Scalar::null(self.dtype.clone()),
872 slice_n_rows,
873 )
874 .into_array()),
875 AllOr::Some(valid_indices) => {
876 let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
877 let valid_views = all_views.slice(
878 slice_value_idx_start - n_skipped_values
879 ..slice_value_idx_stop - n_skipped_values,
880 );
881
882 let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
883 for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
884 views[*index] = view
885 }
886
887 Ok(unsafe {
889 VarBinViewArray::new_unchecked(
890 views.freeze(),
891 Arc::from(buffers),
892 self.dtype.clone(),
893 slice_validity,
894 )
895 }
896 .into_array())
897 }
898 }
899 }
900 _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype),
901 }
902 }
903
904 pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
905 let new_start = self.slice_start + start;
906 let new_stop = self.slice_start + stop;
907
908 assert!(
909 new_start <= self.slice_stop,
910 "new slice start {new_start} exceeds end {}",
911 self.slice_stop
912 );
913
914 assert!(
915 new_stop <= self.slice_stop,
916 "new slice stop {new_stop} exceeds end {}",
917 self.slice_stop
918 );
919
920 ZstdArray {
921 slice_start: self.slice_start + start,
922 slice_stop: self.slice_start + stop,
923 stats_set: Default::default(),
924 ..self.clone()
925 }
926 }
927
928 pub fn into_parts(self) -> ZstdArrayParts {
930 ZstdArrayParts {
931 dictionary: self.dictionary,
932 frames: self.frames,
933 metadata: self.metadata,
934 dtype: self.dtype,
935 validity: self.unsliced_validity,
936 n_rows: self.unsliced_n_rows,
937 slice_start: self.slice_start,
938 slice_stop: self.slice_stop,
939 }
940 }
941
942 pub(crate) fn dtype(&self) -> &DType {
943 &self.dtype
944 }
945
946 pub(crate) fn slice_start(&self) -> usize {
947 self.slice_start
948 }
949
950 pub(crate) fn slice_stop(&self) -> usize {
951 self.slice_stop
952 }
953
954 pub(crate) fn unsliced_n_rows(&self) -> usize {
955 self.unsliced_n_rows
956 }
957}
958
959impl ValiditySliceHelper for ZstdArray {
960 fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
961 (&self.unsliced_validity, self.slice_start, self.slice_stop)
962 }
963}
964
965impl OperationsVTable<Zstd> for Zstd {
966 fn scalar_at(array: &ZstdArray, index: usize, _ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
967 let mut ctx = LEGACY_SESSION.create_execution_ctx();
968 array
969 ._slice(index, index + 1)
970 .decompress(&mut ctx)?
971 .scalar_at(0)
972 }
973}
974
#[cfg(test)]
#[allow(clippy::cast_possible_truncation)]
mod tests {
    use vortex_buffer::ByteBuffer;

    use super::reconstruct_views;
    use crate::array::BinaryView;

    /// Builds the interleaved layout consumed by `reconstruct_views`: for every
    /// string, a little-endian `u32` length followed by the string's bytes.
    fn make_interleaved(strings: &[&[u8]]) -> ByteBuffer {
        let mut buf = Vec::new();
        for s in strings {
            let len = s.len() as u32;
            buf.extend_from_slice(&len.to_le_bytes());
            buf.extend_from_slice(s);
        }
        ByteBuffer::copy_from(buf.as_slice())
    }

    /// Everything fits below the limit, so one segment holds both values.
    #[test]
    fn test_reconstruct_views_no_split() {
        let strings: &[&[u8]] = &[b"hello", b"world"];
        let buf = make_interleaved(strings);
        let (buffers, views) = reconstruct_views(&buf, 1024);

        assert_eq!(buffers.len(), 1);
        assert_eq!(views.len(), 2);
        // "hello" starts right after its 4-byte length prefix.
        assert_eq!(views[0], BinaryView::make_view(b"hello", 0, 4));
        // "world" starts at 4 (prefix) + 5 ("hello") + 4 (prefix) = 13.
        assert_eq!(views[1], BinaryView::make_view(b"world", 0, 13));
    }

    /// The second value would end past the 20-byte limit, so it opens a new segment.
    #[test]
    fn test_reconstruct_views_split_across_segments() {
        // Each entry occupies 4 + 13 = 17 bytes. The first ends at 17 (<= 20); the
        // second would end at 34 (> 20) within the first segment.
        let strings: &[&[u8]] = &[b"aaaaaaaaaaaaa", b"bbbbbbbbbbbbb"];
        let buf = make_interleaved(strings);
        let (buffers, views) = reconstruct_views(&buf, 20);

        assert_eq!(buffers.len(), 2);
        assert_eq!(views.len(), 2);
        assert_eq!(views[0], BinaryView::make_view(b"aaaaaaaaaaaaa", 0, 4));
        // The new segment begins at the second entry's length prefix, so its value
        // again sits at local offset 4, now in buffer 1.
        assert_eq!(views[1], BinaryView::make_view(b"bbbbbbbbbbbbb", 1, 4));
    }
}