1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use itertools::Itertools as _;
12use prost::Message as _;
13use vortex_array::Array;
14use vortex_array::ArrayEq;
15use vortex_array::ArrayHash;
16use vortex_array::ArrayId;
17use vortex_array::ArrayParts;
18use vortex_array::ArrayRef;
19use vortex_array::ArrayView;
20use vortex_array::Canonical;
21use vortex_array::ExecutionCtx;
22use vortex_array::ExecutionResult;
23use vortex_array::IntoArray;
24use vortex_array::LEGACY_SESSION;
25use vortex_array::Precision;
26use vortex_array::ToCanonical;
27use vortex_array::VortexSessionExecute;
28use vortex_array::accessor::ArrayAccessor;
29use vortex_array::arrays::ConstantArray;
30use vortex_array::arrays::PrimitiveArray;
31use vortex_array::arrays::VarBinViewArray;
32use vortex_array::arrays::varbinview::build_views::BinaryView;
33use vortex_array::arrays::varbinview::build_views::MAX_BUFFER_LEN;
34use vortex_array::buffer::BufferHandle;
35use vortex_array::dtype::DType;
36use vortex_array::scalar::Scalar;
37use vortex_array::serde::ArrayChildren;
38use vortex_array::validity::Validity;
39use vortex_array::vtable::OperationsVTable;
40use vortex_array::vtable::VTable;
41use vortex_array::vtable::ValidityVTable;
42use vortex_array::vtable::child_to_validity;
43use vortex_array::vtable::validity_to_child;
44use vortex_buffer::Alignment;
45use vortex_buffer::Buffer;
46use vortex_buffer::BufferMut;
47use vortex_buffer::ByteBuffer;
48use vortex_buffer::ByteBufferMut;
49use vortex_error::VortexError;
50use vortex_error::VortexExpect;
51use vortex_error::VortexResult;
52use vortex_error::vortex_bail;
53use vortex_error::vortex_ensure;
54use vortex_error::vortex_err;
55use vortex_error::vortex_panic;
56use vortex_mask::AllOr;
57use vortex_session::VortexSession;
58use vortex_session::registry::CachedId;
59
60use crate::ZstdFrameMetadata;
61use crate::ZstdMetadata;
62
/// Minimum number of frames (dictionary training samples) required before a
/// zstd dictionary is trained; with fewer frames the data is compressed without one.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
/// Integer type of the little-endian length prefix stored in front of every
/// variable-length value in the uncompressed byte stream.
type ViewLen = u32;

/// An array using the [`Zstd`] encoding.
pub type ZstdArray = Array<Zstd>;
87
88impl ArrayHash for ZstdData {
89 fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
90 match &self.dictionary {
91 Some(dict) => {
92 true.hash(state);
93 dict.array_hash(state, precision);
94 }
95 None => {
96 false.hash(state);
97 }
98 }
99 for frame in &self.frames {
100 frame.array_hash(state, precision);
101 }
102 self.unsliced_n_rows.hash(state);
103 self.slice_start.hash(state);
104 self.slice_stop.hash(state);
105 }
106}
107
108impl ArrayEq for ZstdData {
109 fn array_eq(&self, other: &Self, precision: Precision) -> bool {
110 if !match (&self.dictionary, &other.dictionary) {
111 (Some(d1), Some(d2)) => d1.array_eq(d2, precision),
112 (None, None) => true,
113 _ => false,
114 } {
115 return false;
116 }
117 if self.frames.len() != other.frames.len() {
118 return false;
119 }
120 for (a, b) in self.frames.iter().zip(&other.frames) {
121 if !a.array_eq(b, precision) {
122 return false;
123 }
124 }
125 self.unsliced_n_rows == other.unsliced_n_rows
126 && self.slice_start == other.slice_start
127 && self.slice_stop == other.slice_stop
128 }
129}
130
131impl VTable for Zstd {
132 type ArrayData = ZstdData;
133
134 type OperationsVTable = Self;
135 type ValidityVTable = Self;
136
137 fn id(&self) -> ArrayId {
138 static ID: CachedId = CachedId::new("vortex.zstd");
139 *ID
140 }
141
142 fn validate(
143 &self,
144 data: &Self::ArrayData,
145 dtype: &DType,
146 len: usize,
147 slots: &[Option<ArrayRef>],
148 ) -> VortexResult<()> {
149 let validity = child_to_validity(&slots[0], dtype.nullability());
150 data.validate(dtype, len, &validity)
151 }
152
153 fn nbuffers(array: ArrayView<'_, Self>) -> usize {
154 array.dictionary.is_some() as usize + array.frames.len()
155 }
156
157 fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
158 if let Some(dict) = &array.dictionary {
159 if idx == 0 {
160 return BufferHandle::new_host(dict.clone());
161 }
162 BufferHandle::new_host(array.frames[idx - 1].clone())
163 } else {
164 BufferHandle::new_host(array.frames[idx].clone())
165 }
166 }
167
168 fn buffer_name(array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
169 if array.dictionary.is_some() {
170 if idx == 0 {
171 Some("dictionary".to_string())
172 } else {
173 Some(format!("frame_{}", idx - 1))
174 }
175 } else {
176 Some(format!("frame_{idx}"))
177 }
178 }
179
180 fn serialize(
181 array: ArrayView<'_, Self>,
182 _session: &VortexSession,
183 ) -> VortexResult<Option<Vec<u8>>> {
184 Ok(Some(array.metadata.clone().encode_to_vec()))
185 }
186
187 fn deserialize(
188 &self,
189 dtype: &DType,
190 len: usize,
191 metadata: &[u8],
192 buffers: &[BufferHandle],
193 children: &dyn ArrayChildren,
194 _session: &VortexSession,
195 ) -> VortexResult<ArrayParts<Self>> {
196 let metadata = ZstdMetadata::decode(metadata)?;
197 let validity = if children.is_empty() {
198 Validity::from(dtype.nullability())
199 } else if children.len() == 1 {
200 let validity = children.get(0, &Validity::DTYPE, len)?;
201 Validity::Array(validity)
202 } else {
203 vortex_bail!("ZstdArray expected 0 or 1 child, got {}", children.len());
204 };
205
206 let (dictionary_buffer, compressed_buffers) = if metadata.dictionary_size == 0 {
207 (
209 None,
210 buffers
211 .iter()
212 .map(|b| b.clone().try_to_host_sync())
213 .collect::<VortexResult<Vec<_>>>()?,
214 )
215 } else {
216 (
218 Some(buffers[0].clone().try_to_host_sync()?),
219 buffers[1..]
220 .iter()
221 .map(|b| b.clone().try_to_host_sync())
222 .collect::<VortexResult<Vec<_>>>()?,
223 )
224 };
225
226 let slots = vec![validity_to_child(&validity, len)];
227 let data = ZstdData::new(dictionary_buffer, compressed_buffers, metadata, len);
228 Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
229 }
230
231 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
232 SLOT_NAMES[idx].to_string()
233 }
234
235 fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
236 let unsliced_validity =
237 child_to_validity(&array.as_ref().slots()[0], array.dtype().nullability());
238 array
239 .data()
240 .decompress(array.dtype(), &unsliced_validity, ctx)?
241 .execute::<ArrayRef>(ctx)
242 .map(ExecutionResult::done)
243 }
244
245 fn reduce_parent(
246 array: ArrayView<'_, Self>,
247 parent: &ArrayRef,
248 child_idx: usize,
249 ) -> VortexResult<Option<ArrayRef>> {
250 crate::rules::RULES.evaluate(array, parent, child_idx)
251 }
252}
253
/// Marker type for the Zstd encoding; it carries the vtable implementations and
/// the constructors that build a [`ZstdArray`].
#[derive(Clone, Debug)]
pub struct Zstd;
256
257impl Zstd {
258 pub fn try_new(dtype: DType, data: ZstdData, validity: Validity) -> VortexResult<ZstdArray> {
259 let len = data.len();
260 data.validate(&dtype, len, &validity)?;
261 let slots = vec![validity_to_child(&validity, data.unsliced_n_rows())];
262 Ok(unsafe {
263 Array::from_parts_unchecked(ArrayParts::new(Zstd, dtype, len, data).with_slots(slots))
264 })
265 }
266
267 pub fn from_var_bin_view_without_dict(
269 vbv: &VarBinViewArray,
270 level: i32,
271 values_per_frame: usize,
272 ) -> VortexResult<ZstdArray> {
273 let validity = vbv.validity()?;
274 Self::try_new(
275 vbv.dtype().clone(),
276 ZstdData::from_var_bin_view_without_dict(vbv, level, values_per_frame)?,
277 validity,
278 )
279 }
280
281 pub fn from_primitive(
283 parray: &PrimitiveArray,
284 level: i32,
285 values_per_frame: usize,
286 ) -> VortexResult<ZstdArray> {
287 let validity = parray.validity()?;
288 Self::try_new(
289 parray.dtype().clone(),
290 ZstdData::from_primitive(parray, level, values_per_frame)?,
291 validity,
292 )
293 }
294
295 pub fn from_var_bin_view(
297 vbv: &VarBinViewArray,
298 level: i32,
299 values_per_frame: usize,
300 ) -> VortexResult<ZstdArray> {
301 let validity = vbv.validity()?;
302 Self::try_new(
303 vbv.dtype().clone(),
304 ZstdData::from_var_bin_view(vbv, level, values_per_frame)?,
305 validity,
306 )
307 }
308
309 pub fn decompress(array: &ZstdArray, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
310 let unsliced_validity =
311 child_to_validity(&array.as_ref().slots()[0], array.dtype().nullability());
312 array
313 .data()
314 .decompress(array.dtype(), &unsliced_validity, ctx)
315 }
316}
317
/// Number of child slots of a Zstd array.
pub(super) const NUM_SLOTS: usize = 1;
/// Names of the child slots, indexed by slot number; slot 0 is the validity.
pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
321
/// Compressed payload of a Zstd array together with its slicing state.
///
/// Slicing only narrows `slice_start..slice_stop`; the frames and the dictionary
/// are kept as they are.
#[derive(Clone, Debug)]
pub struct ZstdData {
    /// Trained zstd dictionary, if one was used for compression.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// Compressed frames, in order.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Dictionary size and per-frame metadata, parallel to `frames`.
    pub(crate) metadata: ZstdMetadata,
    /// Number of rows before any slicing was applied.
    unsliced_n_rows: usize,
    /// First visible row, as an index into the unsliced rows.
    slice_start: usize,
    /// One past the last visible row, as an index into the unsliced rows.
    slice_stop: usize,
}
331
332impl Display for ZstdData {
333 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
334 write!(
335 f,
336 "nrows: {}, slice: {}..{}",
337 self.unsliced_n_rows, self.slice_start, self.slice_stop
338 )
339 }
340}
341
/// Owned components of a [`ZstdData`] plus its validity, as produced by
/// [`ZstdData::into_parts`].
pub struct ZstdDataParts {
    /// Trained zstd dictionary, if any.
    pub dictionary: Option<ByteBuffer>,
    /// Compressed frames, in order.
    pub frames: Vec<ByteBuffer>,
    /// Dictionary size and per-frame metadata.
    pub metadata: ZstdMetadata,
    /// Validity passed by the caller of `into_parts`.
    pub validity: Validity,
    /// Number of rows before slicing.
    pub n_rows: usize,
    /// First visible row.
    pub slice_start: usize,
    /// One past the last visible row.
    pub slice_stop: usize,
}
351
/// Result of compressing a value byte stream: the optional trained dictionary,
/// the compressed frames and one metadata entry per frame.
#[derive(Debug)]
struct Frames {
    /// Dictionary trained on the frames, if one was used.
    dictionary: Option<ByteBuffer>,
    /// Compressed frames, in order.
    frames: Vec<ByteBuffer>,
    /// Uncompressed size and value count of each frame, parallel to `frames`.
    frame_metas: Vec<ZstdFrameMetadata>,
}
359
/// Picks the maximum dictionary size for `uncompressed_size` bytes of data:
/// one hundredth of the data size, bounded to the range 256 bytes ..= 100 KiB.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const MIN_DICT_SIZE: usize = 256;
    const MAX_DICT_SIZE: usize = 100 * 1024;

    let one_percent = uncompressed_size / 100;
    one_percent.max(MIN_DICT_SIZE).min(MAX_DICT_SIZE)
}
367
368fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
369 let mask = parray.as_ref().validity()?.to_mask(
370 parray.as_ref().len(),
371 &mut LEGACY_SESSION.create_execution_ctx(),
372 )?;
373 Ok(parray.filter(mask)?.to_primitive())
374}
375
376fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
377 let mask = vbv.as_ref().validity()?.to_mask(
378 vbv.as_ref().len(),
379 &mut LEGACY_SESSION.create_execution_ctx(),
380 )?;
381 let buffer_and_value_byte_indices = match mask.bit_buffer() {
382 AllOr::None => (Buffer::empty(), Vec::new()),
383 _ => {
384 let mut buffer = BufferMut::with_capacity(
385 usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
386 + mask.true_count() * size_of::<ViewLen>(),
387 );
388 let mut value_byte_indices = Vec::new();
389 vbv.with_iterator(|iterator| {
390 for value in iterator.flatten() {
392 value_byte_indices.push(buffer.len());
393 buffer
395 .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
396 buffer.extend_from_slice(value);
397 }
398 Ok::<_, VortexError>(())
399 })?;
400 (buffer.freeze(), value_byte_indices)
401 }
402 };
403 Ok(buffer_and_value_byte_indices)
404}
405
406pub fn reconstruct_views(
415 buffer: &ByteBuffer,
416 max_buffer_len: usize,
417) -> (Vec<ByteBuffer>, Buffer<BinaryView>) {
418 let mut views = BufferMut::<BinaryView>::empty();
419 let mut buffers = Vec::new();
420 let mut segment_start: usize = 0;
421 let mut offset = 0;
422
423 while offset < buffer.len() {
424 let str_len = ViewLen::from_le_bytes(
425 buffer
426 .get(offset..offset + size_of::<ViewLen>())
427 .vortex_expect("corrupted zstd length")
428 .try_into()
429 .ok()
430 .vortex_expect("must fit ViewLen size"),
431 ) as usize;
432
433 let value_data_offset = offset + size_of::<ViewLen>();
434 let local_offset = value_data_offset - segment_start;
435
436 if local_offset + str_len > max_buffer_len && offset > segment_start {
437 buffers.push(buffer.slice(segment_start..offset));
438 segment_start = offset;
439 }
440
441 let local_offset = u32::try_from(value_data_offset - segment_start)
442 .vortex_expect("local offset within segment must fit in u32");
443 let buf_index = u32::try_from(buffers.len()).vortex_expect("buffer index must fit in u32");
444 let value = &buffer[value_data_offset..value_data_offset + str_len];
445 views.push(BinaryView::make_view(value, buf_index, local_offset));
446 offset = value_data_offset + str_len;
447 }
448
449 if segment_start < buffer.len() {
450 buffers.push(buffer.slice(segment_start..buffer.len()));
451 }
452
453 (buffers, views.freeze())
454}
455
456impl ZstdData {
457 pub fn new(
458 dictionary: Option<ByteBuffer>,
459 frames: Vec<ByteBuffer>,
460 metadata: ZstdMetadata,
461 n_rows: usize,
462 ) -> Self {
463 Self {
464 dictionary,
465 frames,
466 metadata,
467 unsliced_n_rows: n_rows,
468 slice_start: 0,
469 slice_stop: n_rows,
470 }
471 }
472
473 pub fn validate(&self, dtype: &DType, len: usize, validity: &Validity) -> VortexResult<()> {
474 vortex_ensure!(
475 matches!(
476 dtype,
477 DType::Primitive(..) | DType::Binary(_) | DType::Utf8(_)
478 ),
479 "Unsupported dtype for Zstd array: {dtype}"
480 );
481 vortex_ensure!(
482 self.slice_start <= self.slice_stop,
483 "Invalid slice range {}..{}",
484 self.slice_start,
485 self.slice_stop
486 );
487 vortex_ensure!(
488 self.slice_stop <= self.unsliced_n_rows,
489 "Slice stop {} exceeds unsliced row count {}",
490 self.slice_stop,
491 self.unsliced_n_rows
492 );
493 vortex_ensure!(
494 self.slice_stop - self.slice_start == len,
495 "Slice length {} does not match array length {}",
496 self.slice_stop - self.slice_start,
497 len
498 );
499 if let Some(validity_len) = validity.maybe_len() {
500 vortex_ensure!(
501 validity_len == self.unsliced_n_rows,
502 "Validity length {} does not match unsliced row count {}",
503 validity_len,
504 self.unsliced_n_rows
505 );
506 }
507
508 match &self.dictionary {
509 Some(dictionary) => vortex_ensure!(
510 usize::try_from(self.metadata.dictionary_size)? == dictionary.len(),
511 "Dictionary size metadata {} does not match buffer size {}",
512 self.metadata.dictionary_size,
513 dictionary.len()
514 ),
515 None => vortex_ensure!(
516 self.metadata.dictionary_size == 0,
517 "Dictionary metadata present without dictionary buffer"
518 ),
519 }
520 vortex_ensure!(
521 self.frames.len() == self.metadata.frames.len(),
522 "Frame count {} does not match metadata frame count {}",
523 self.frames.len(),
524 self.metadata.frames.len()
525 );
526
527 Ok(())
528 }
529
530 pub(crate) fn with_slice(&self, start: usize, stop: usize) -> Self {
531 let new_start = self.slice_start + start;
532 let new_stop = self.slice_start + stop;
533
534 assert!(
535 new_start <= self.slice_stop,
536 "new slice start {new_start} exceeds end {}",
537 self.slice_stop
538 );
539
540 assert!(
541 new_stop <= self.slice_stop,
542 "new slice stop {new_stop} exceeds end {}",
543 self.slice_stop
544 );
545
546 Self {
547 slice_start: new_start,
548 slice_stop: new_stop,
549 ..self.clone()
550 }
551 }
552
553 fn compress_values(
554 value_bytes: &ByteBuffer,
555 frame_byte_starts: &[usize],
556 level: i32,
557 values_per_frame: usize,
558 n_values: usize,
559 use_dictionary: bool,
560 ) -> VortexResult<Frames> {
561 let n_frames = frame_byte_starts.len();
562
563 let mut sample_sizes = Vec::with_capacity(n_frames);
565 for i in 0..n_frames {
566 let frame_byte_end = frame_byte_starts
567 .get(i + 1)
568 .copied()
569 .unwrap_or(value_bytes.len());
570 sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
571 }
572 debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
573
574 let (dictionary, mut compressor) = if !use_dictionary
575 || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY
576 {
577 (None, zstd::bulk::Compressor::new(level)?)
579 } else {
580 let max_dict_size = choose_max_dict_size(value_bytes.len());
582 let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
583 .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
584
585 let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
586 (Some(ByteBuffer::from(dict)), compressor)
587 };
588
589 let mut frame_metas = vec![];
590 let mut frames = vec![];
591 for i in 0..n_frames {
592 let frame_byte_end = frame_byte_starts
593 .get(i + 1)
594 .copied()
595 .unwrap_or(value_bytes.len());
596
597 let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
598 let compressed = compressor
599 .compress(uncompressed)
600 .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
601 frame_metas.push(ZstdFrameMetadata {
602 uncompressed_size: uncompressed.len() as u64,
603 n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
604 });
605 frames.push(ByteBuffer::from(compressed));
606 }
607
608 Ok(Frames {
609 dictionary,
610 frames,
611 frame_metas,
612 })
613 }
614
615 pub fn from_primitive(
622 parray: &PrimitiveArray,
623 level: i32,
624 values_per_frame: usize,
625 ) -> VortexResult<Self> {
626 Self::from_primitive_impl(parray, level, values_per_frame, true)
627 }
628
629 pub fn from_primitive_without_dict(
643 parray: &PrimitiveArray,
644 level: i32,
645 values_per_frame: usize,
646 ) -> VortexResult<Self> {
647 Self::from_primitive_impl(parray, level, values_per_frame, false)
648 }
649
650 fn from_primitive_impl(
651 parray: &PrimitiveArray,
652 level: i32,
653 values_per_frame: usize,
654 use_dictionary: bool,
655 ) -> VortexResult<Self> {
656 let byte_width = parray.ptype().byte_width();
657
658 let values = collect_valid_primitive(parray)?;
660 let n_values = values.len();
661 let values_per_frame = if values_per_frame > 0 {
662 values_per_frame
663 } else {
664 n_values
665 };
666
667 let value_bytes = values.buffer_handle().try_to_host_sync()?;
668 let alignment = *value_bytes.alignment();
670 let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment;
671
672 let frame_byte_starts = (0..n_values * byte_width)
673 .step_by(step_width)
674 .collect::<Vec<_>>();
675 let Frames {
676 dictionary,
677 frames,
678 frame_metas,
679 } = Self::compress_values(
680 &value_bytes,
681 &frame_byte_starts,
682 level,
683 values_per_frame,
684 n_values,
685 use_dictionary,
686 )?;
687
688 let metadata = ZstdMetadata {
689 dictionary_size: dictionary
690 .as_ref()
691 .map_or(0, |dict| dict.len())
692 .try_into()?,
693 frames: frame_metas,
694 };
695
696 Ok(ZstdData::new(dictionary, frames, metadata, parray.len()))
697 }
698
699 pub fn from_var_bin_view(
706 vbv: &VarBinViewArray,
707 level: i32,
708 values_per_frame: usize,
709 ) -> VortexResult<Self> {
710 Self::from_var_bin_view_impl(vbv, level, values_per_frame, true)
711 }
712
713 pub fn from_var_bin_view_without_dict(
727 vbv: &VarBinViewArray,
728 level: i32,
729 values_per_frame: usize,
730 ) -> VortexResult<Self> {
731 Self::from_var_bin_view_impl(vbv, level, values_per_frame, false)
732 }
733
734 fn from_var_bin_view_impl(
735 vbv: &VarBinViewArray,
736 level: i32,
737 values_per_frame: usize,
738 use_dictionary: bool,
739 ) -> VortexResult<Self> {
740 let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
747 let n_values = value_byte_indices.len();
748 let values_per_frame = if values_per_frame > 0 {
749 values_per_frame
750 } else {
751 n_values
752 };
753
754 let frame_byte_starts = (0..n_values)
755 .step_by(values_per_frame)
756 .map(|i| value_byte_indices[i])
757 .collect::<Vec<_>>();
758 let Frames {
759 dictionary,
760 frames,
761 frame_metas,
762 } = Self::compress_values(
763 &value_bytes,
764 &frame_byte_starts,
765 level,
766 values_per_frame,
767 n_values,
768 use_dictionary,
769 )?;
770
771 let metadata = ZstdMetadata {
772 dictionary_size: dictionary
773 .as_ref()
774 .map_or(0, |dict| dict.len())
775 .try_into()?,
776 frames: frame_metas,
777 };
778 Ok(ZstdData::new(dictionary, frames, metadata, vbv.len()))
779 }
780
781 pub fn from_canonical(
782 canonical: &Canonical,
783 level: i32,
784 values_per_frame: usize,
785 ) -> VortexResult<Option<Self>> {
786 match canonical {
787 Canonical::Primitive(parray) => Ok(Some(ZstdData::from_primitive(
788 parray,
789 level,
790 values_per_frame,
791 )?)),
792 Canonical::VarBinView(vbv) => Ok(Some(ZstdData::from_var_bin_view(
793 vbv,
794 level,
795 values_per_frame,
796 )?)),
797 _ => Ok(None),
798 }
799 }
800
801 pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
802 Self::from_canonical(&array.to_canonical()?, level, values_per_frame)?
803 .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
804 }
805
806 fn byte_width(dtype: &DType) -> usize {
807 if dtype.is_primitive() {
808 dtype.as_ptype().byte_width()
809 } else {
810 1
811 }
812 }
813
    /// Decompresses the rows in `slice_start..slice_stop` into an uncompressed array.
    ///
    /// Only the frames that overlap the valid values of the slice are decompressed.
    /// `unsliced_validity` is the validity of all `unsliced_n_rows` rows; it is
    /// sliced here to match the visible range.
    ///
    /// # Errors
    ///
    /// Returns an error if the validity mask cannot be computed, if the
    /// decompressor cannot be created, or if a frame fails to decompress.
    ///
    /// # Panics
    ///
    /// Panics if the number of decompressed bytes differs from what the frame
    /// metadata announces, if a non-nullable dtype is paired with a validity that
    /// is neither `NonNullable` nor `AllValid`, or if `dtype` is not primitive,
    /// binary or UTF-8.
    fn decompress(
        &self,
        dtype: &DType,
        unsliced_validity: &Validity,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<ArrayRef> {
        let byte_width = Self::byte_width(dtype);
        let slice_n_rows = self.slice_stop - self.slice_start;
        // Translate the row bounds of the slice into bounds over the stored
        // (valid-only) values.
        // NOTE(review): presumably `valid_counts_for_indices` returns the number of
        // valid rows before each index; confirm.
        let slice_value_indices = unsliced_validity
            .execute_mask(self.unsliced_n_rows, ctx)?
            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);

        let slice_value_idx_start = slice_value_indices[0];
        let slice_value_idx_stop = slice_value_indices[1];

        // Select the frames overlapping the requested value range.
        let mut frames_to_decompress = vec![];
        let mut value_idx_start = 0;
        let mut uncompressed_size_to_decompress = 0;
        let mut n_skipped_values = 0;
        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
            // Every later frame starts at or after the end of the slice.
            if value_idx_start >= slice_value_idx_stop {
                break;
            }

            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
                .vortex_expect("Uncompressed size must fit in usize");
            // A recorded value count of 0 falls back to deriving the count from the
            // uncompressed size.
            let frame_n_values = if frame_meta.n_values == 0 {
                frame_uncompressed_size / byte_width
            } else {
                usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
            };

            let value_idx_stop = value_idx_start + frame_n_values;
            if value_idx_stop > slice_value_idx_start {
                // The frame overlaps the slice.
                frames_to_decompress.push(frame);
                uncompressed_size_to_decompress += frame_uncompressed_size;
            } else {
                // The frame lies entirely before the slice; remember how many values
                // were skipped so later offsets can be made relative.
                n_skipped_values += frame_n_values;
            }
            value_idx_start = value_idx_stop;
        }

        // The same dictionary used for compression must be used for decompression.
        let mut decompressor = if let Some(dictionary) = &self.dictionary {
            zstd::bulk::Decompressor::with_dictionary(dictionary)?
        } else {
            zstd::bulk::Decompressor::new()?
        };
        let mut decompressed = ByteBufferMut::with_capacity_aligned(
            uncompressed_size_to_decompress,
            Alignment::new(byte_width),
        );
        // SAFETY (review): the length is set to the capacity requested just above;
        // the bytes are uninitialized until the loop below writes them, and the
        // total written is checked against this length afterwards. Assumes
        // `with_capacity_aligned` reserves at least the requested capacity.
        unsafe {
            decompressed.set_len(uncompressed_size_to_decompress);
        }
        // Decompress the selected frames back to back into the output buffer.
        let mut uncompressed_start = 0;
        for frame in frames_to_decompress {
            let uncompressed_written = decompressor
                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])?;
            uncompressed_start += uncompressed_written;
        }
        if uncompressed_start != uncompressed_size_to_decompress {
            vortex_panic!(
                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
                uncompressed_size_to_decompress,
                uncompressed_start
            );
        }

        let decompressed = decompressed.freeze();
        let mut slice_validity = unsliced_validity.slice(self.slice_start..self.slice_stop)?;

        // Make the validity representation agree with the dtype's nullability.
        if !dtype.is_nullable() && !matches!(slice_validity, Validity::NonNullable) {
            assert!(
                matches!(slice_validity, Validity::AllValid),
                "ZSTD array expects to be non-nullable but there are nulls after decompression"
            );

            slice_validity = Validity::NonNullable;
        } else if dtype.is_nullable() && matches!(slice_validity, Validity::NonNullable) {
            slice_validity = Validity::AllValid;
        }
        match dtype {
            DType::Primitive(..) => {
                // Value indices are made relative to the first decompressed frame.
                let slice_values_buffer = decompressed.slice(
                    (slice_value_idx_start - n_skipped_values) * byte_width
                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
                );
                let primitive = PrimitiveArray::from_values_byte_buffer(
                    slice_values_buffer,
                    dtype.as_ptype(),
                    slice_validity,
                    slice_n_rows,
                );

                Ok(primitive.into_array())
            }
            DType::Binary(_) | DType::Utf8(_) => {
                match slice_validity.execute_mask(slice_n_rows, ctx)?.indices() {
                    AllOr::All => {
                        // Every row is valid, so the views map one-to-one onto rows.
                        let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
                        let valid_views = all_views.slice(
                            slice_value_idx_start - n_skipped_values
                                ..slice_value_idx_stop - n_skipped_values,
                        );

                        // NOTE(review): `new_unchecked` skips validation; the views
                        // and buffers come straight from `reconstruct_views`. Confirm
                        // this satisfies its contract.
                        Ok(unsafe {
                            VarBinViewArray::new_unchecked(
                                valid_views,
                                Arc::from(buffers),
                                dtype.clone(),
                                slice_validity,
                            )
                        }
                        .into_array())
                    }
                    // No valid row in the slice: the result is a constant null array.
                    AllOr::None => Ok(ConstantArray::new(
                        Scalar::null(dtype.clone()),
                        slice_n_rows,
                    )
                    .into_array()),
                    AllOr::Some(valid_indices) => {
                        let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
                        let valid_views = all_views.slice(
                            slice_value_idx_start - n_skipped_values
                                ..slice_value_idx_stop - n_skipped_values,
                        );

                        // Scatter the valid views to their row positions; null rows
                        // keep a zeroed view.
                        let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
                        for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
                            views[*index] = view
                        }

                        // NOTE(review): `new_unchecked` skips validation; confirm that
                        // zeroed views are acceptable for null rows.
                        Ok(unsafe {
                            VarBinViewArray::new_unchecked(
                                views.freeze(),
                                Arc::from(buffers),
                                dtype.clone(),
                                slice_validity,
                            )
                        }
                        .into_array())
                    }
                }
            }
            _ => vortex_panic!("Unsupported dtype for Zstd array: {}", dtype),
        }
    }
983
    /// Number of visible rows, i.e. the length of `slice_start..slice_stop`.
    #[inline]
    pub fn len(&self) -> usize {
        self.slice_stop - self.slice_start
    }
989
    /// Returns `true` when the visible slice contains no rows.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.slice_stop == self.slice_start
    }
995
996 pub fn into_parts(self, validity: Validity) -> ZstdDataParts {
997 ZstdDataParts {
998 dictionary: self.dictionary,
999 frames: self.frames,
1000 metadata: self.metadata,
1001 validity,
1002 n_rows: self.unsliced_n_rows,
1003 slice_start: self.slice_start,
1004 slice_stop: self.slice_stop,
1005 }
1006 }
1007
    /// First visible row, as an index into the unsliced rows.
    pub(crate) fn slice_start(&self) -> usize {
        self.slice_start
    }
1011
    /// One past the last visible row, as an index into the unsliced rows.
    pub(crate) fn slice_stop(&self) -> usize {
        self.slice_stop
    }
1015
    /// Number of rows before any slicing was applied.
    pub(crate) fn unsliced_n_rows(&self) -> usize {
        self.unsliced_n_rows
    }
1019}
1020
1021impl ValidityVTable<Zstd> for Zstd {
1022 fn validity(array: ArrayView<'_, Zstd>) -> VortexResult<Validity> {
1023 let unsliced_validity = child_to_validity(&array.slots()[0], array.dtype().nullability());
1024 unsliced_validity.slice(array.slice_start()..array.slice_stop())
1025 }
1026}
1027
1028impl OperationsVTable<Zstd> for Zstd {
1029 fn scalar_at(
1030 array: ArrayView<'_, Zstd>,
1031 index: usize,
1032 ctx: &mut ExecutionCtx,
1033 ) -> VortexResult<Scalar> {
1034 let unsliced_validity = child_to_validity(&array.slots()[0], array.dtype().nullability());
1035 let sliced = array.data().with_slice(index, index + 1);
1036 sliced
1037 .decompress(array.dtype(), &unsliced_validity, ctx)?
1038 .execute_scalar(0, ctx)
1039 }
1040}
1041
// Unit tests for `reconstruct_views`.
#[cfg(test)]
#[expect(clippy::cast_possible_truncation)]
mod tests {
    use vortex_buffer::ByteBuffer;

    use super::reconstruct_views;
    use crate::array::BinaryView;

    /// Builds the length-prefixed layout consumed by `reconstruct_views`: for each
    /// string, a little-endian `u32` length followed by the string bytes.
    fn make_interleaved(strings: &[&[u8]]) -> ByteBuffer {
        let mut buf = Vec::new();
        for s in strings {
            let len = s.len() as u32;
            buf.extend_from_slice(&len.to_le_bytes());
            buf.extend_from_slice(s);
        }
        ByteBuffer::copy_from(buf.as_slice())
    }

    /// With a generous limit everything stays in one segment.
    #[test]
    fn test_reconstruct_views_no_split() {
        let strings: &[&[u8]] = &[b"hello", b"world"];
        let buf = make_interleaved(strings);
        let (buffers, views) = reconstruct_views(&buf, 1024);

        assert_eq!(buffers.len(), 1);
        assert_eq!(views.len(), 2);
        // "hello" follows its own 4-byte prefix.
        assert_eq!(views[0], BinaryView::make_view(b"hello", 0, 4));
        // "world" starts after 4 + 5 bytes of the first entry plus its own 4-byte prefix.
        assert_eq!(views[1], BinaryView::make_view(b"world", 0, 13));
    }

    /// Two 13-byte strings cannot share a 20-byte segment, so each gets its own.
    #[test]
    fn test_reconstruct_views_split_across_segments() {
        let strings: &[&[u8]] = &[b"aaaaaaaaaaaaa", b"bbbbbbbbbbbbb"];
        let buf = make_interleaved(strings);
        let (buffers, views) = reconstruct_views(&buf, 20);

        assert_eq!(buffers.len(), 2);
        assert_eq!(views.len(), 2);
        assert_eq!(views[0], BinaryView::make_view(b"aaaaaaaaaaaaa", 0, 4));
        // The second value sits right after the length prefix of segment 1.
        assert_eq!(views[1], BinaryView::make_view(b"bbbbbbbbbbbbb", 1, 4));
    }
}