1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use itertools::Itertools as _;
12use prost::Message as _;
13use vortex_array::Array;
14use vortex_array::ArrayEq;
15use vortex_array::ArrayHash;
16use vortex_array::ArrayId;
17use vortex_array::ArrayParts;
18use vortex_array::ArrayRef;
19use vortex_array::ArrayView;
20use vortex_array::Canonical;
21use vortex_array::ExecutionCtx;
22use vortex_array::ExecutionResult;
23use vortex_array::IntoArray;
24use vortex_array::Precision;
25use vortex_array::accessor::ArrayAccessor;
26use vortex_array::arrays::ConstantArray;
27use vortex_array::arrays::PrimitiveArray;
28use vortex_array::arrays::VarBinViewArray;
29use vortex_array::arrays::varbinview::build_views::BinaryView;
30use vortex_array::arrays::varbinview::build_views::MAX_BUFFER_LEN;
31use vortex_array::buffer::BufferHandle;
32use vortex_array::dtype::DType;
33use vortex_array::scalar::Scalar;
34use vortex_array::serde::ArrayChildren;
35use vortex_array::smallvec::smallvec;
36use vortex_array::validity::Validity;
37use vortex_array::vtable::OperationsVTable;
38use vortex_array::vtable::VTable;
39use vortex_array::vtable::ValidityVTable;
40use vortex_array::vtable::child_to_validity;
41use vortex_array::vtable::validity_to_child;
42use vortex_buffer::Alignment;
43use vortex_buffer::Buffer;
44use vortex_buffer::BufferMut;
45use vortex_buffer::ByteBuffer;
46use vortex_buffer::ByteBufferMut;
47use vortex_error::VortexError;
48use vortex_error::VortexExpect;
49use vortex_error::VortexResult;
50use vortex_error::vortex_bail;
51use vortex_error::vortex_ensure;
52use vortex_error::vortex_err;
53use vortex_error::vortex_panic;
54use vortex_mask::AllOr;
55use vortex_session::VortexSession;
56use vortex_session::registry::CachedId;
57
58use crate::ZstdFrameMetadata;
59use crate::ZstdMetadata;
60
/// Minimum number of frames (each frame is one training sample) required before a zstd
/// dictionary is trained; with fewer frames the values are compressed without a dictionary.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
/// Integer type of the little-endian length prefix written in front of every binary value
/// before compression.
type ViewLen = u32;

/// An [`Array`] backed by the [`Zstd`] vtable.
pub type ZstdArray = Array<Zstd>;
85
86impl ArrayHash for ZstdData {
87 fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
88 match &self.dictionary {
89 Some(dict) => {
90 true.hash(state);
91 dict.array_hash(state, precision);
92 }
93 None => {
94 false.hash(state);
95 }
96 }
97 for frame in &self.frames {
98 frame.array_hash(state, precision);
99 }
100 self.unsliced_n_rows.hash(state);
101 self.slice_start.hash(state);
102 self.slice_stop.hash(state);
103 }
104}
105
106impl ArrayEq for ZstdData {
107 fn array_eq(&self, other: &Self, precision: Precision) -> bool {
108 if !match (&self.dictionary, &other.dictionary) {
109 (Some(d1), Some(d2)) => d1.array_eq(d2, precision),
110 (None, None) => true,
111 _ => false,
112 } {
113 return false;
114 }
115 if self.frames.len() != other.frames.len() {
116 return false;
117 }
118 for (a, b) in self.frames.iter().zip(&other.frames) {
119 if !a.array_eq(b, precision) {
120 return false;
121 }
122 }
123 self.unsliced_n_rows == other.unsliced_n_rows
124 && self.slice_start == other.slice_start
125 && self.slice_stop == other.slice_stop
126 }
127}
128
129impl VTable for Zstd {
130 type TypedArrayData = ZstdData;
131
132 type OperationsVTable = Self;
133 type ValidityVTable = Self;
134
135 fn id(&self) -> ArrayId {
136 static ID: CachedId = CachedId::new("vortex.zstd");
137 *ID
138 }
139
140 fn validate(
141 &self,
142 data: &Self::TypedArrayData,
143 dtype: &DType,
144 len: usize,
145 slots: &[Option<ArrayRef>],
146 ) -> VortexResult<()> {
147 let validity = child_to_validity(slots[0].as_ref(), dtype.nullability());
148 data.validate(dtype, len, &validity)
149 }
150
151 fn nbuffers(array: ArrayView<'_, Self>) -> usize {
152 array.dictionary.is_some() as usize + array.frames.len()
153 }
154
155 fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
156 if let Some(dict) = &array.dictionary {
157 if idx == 0 {
158 return BufferHandle::new_host(dict.clone());
159 }
160 BufferHandle::new_host(array.frames[idx - 1].clone())
161 } else {
162 BufferHandle::new_host(array.frames[idx].clone())
163 }
164 }
165
166 fn buffer_name(array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
167 if array.dictionary.is_some() {
168 if idx == 0 {
169 Some("dictionary".to_string())
170 } else {
171 Some(format!("frame_{}", idx - 1))
172 }
173 } else {
174 Some(format!("frame_{idx}"))
175 }
176 }
177
178 fn serialize(
179 array: ArrayView<'_, Self>,
180 _session: &VortexSession,
181 ) -> VortexResult<Option<Vec<u8>>> {
182 Ok(Some(array.metadata.clone().encode_to_vec()))
183 }
184
185 fn deserialize(
186 &self,
187 dtype: &DType,
188 len: usize,
189 metadata: &[u8],
190 buffers: &[BufferHandle],
191 children: &dyn ArrayChildren,
192 _session: &VortexSession,
193 ) -> VortexResult<ArrayParts<Self>> {
194 let metadata = ZstdMetadata::decode(metadata)?;
195 let validity = if children.is_empty() {
196 Validity::from(dtype.nullability())
197 } else if children.len() == 1 {
198 let validity = children.get(0, &Validity::DTYPE, len)?;
199 Validity::Array(validity)
200 } else {
201 vortex_bail!("ZstdArray expected 0 or 1 child, got {}", children.len());
202 };
203
204 let (dictionary_buffer, compressed_buffers) = if metadata.dictionary_size == 0 {
205 (
207 None,
208 buffers
209 .iter()
210 .map(|b| b.clone().try_to_host_sync())
211 .collect::<VortexResult<Vec<_>>>()?,
212 )
213 } else {
214 (
216 Some(buffers[0].clone().try_to_host_sync()?),
217 buffers[1..]
218 .iter()
219 .map(|b| b.clone().try_to_host_sync())
220 .collect::<VortexResult<Vec<_>>>()?,
221 )
222 };
223
224 let slots = smallvec![validity_to_child(&validity, len)];
225 let data = ZstdData::new(dictionary_buffer, compressed_buffers, metadata, len);
226 Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
227 }
228
229 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
230 SLOT_NAMES[idx].to_string()
231 }
232
233 fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
234 let unsliced_validity = child_to_validity(
235 array.as_ref().slots()[0].as_ref(),
236 array.dtype().nullability(),
237 );
238 array
239 .data()
240 .decompress(array.dtype(), &unsliced_validity, ctx)?
241 .execute::<ArrayRef>(ctx)
242 .map(ExecutionResult::done)
243 }
244
245 fn reduce_parent(
246 array: ArrayView<'_, Self>,
247 parent: &ArrayRef,
248 child_idx: usize,
249 ) -> VortexResult<Option<ArrayRef>> {
250 crate::rules::RULES.evaluate(array, parent, child_idx)
251 }
252}
253
/// VTable for zstd-compressed arrays, registered under the array id `vortex.zstd`.
#[derive(Clone, Debug)]
pub struct Zstd;
256
257impl Zstd {
258 pub fn try_new(dtype: DType, data: ZstdData, validity: Validity) -> VortexResult<ZstdArray> {
259 let len = data.len();
260 data.validate(&dtype, len, &validity)?;
261 let slots = smallvec![validity_to_child(&validity, data.unsliced_n_rows())];
262 Ok(unsafe {
263 Array::from_parts_unchecked(ArrayParts::new(Zstd, dtype, len, data).with_slots(slots))
264 })
265 }
266
267 pub fn from_var_bin_view_without_dict(
269 vbv: &VarBinViewArray,
270 level: i32,
271 values_per_frame: usize,
272 ctx: &mut ExecutionCtx,
273 ) -> VortexResult<ZstdArray> {
274 let validity = vbv.validity()?;
275 Self::try_new(
276 vbv.dtype().clone(),
277 ZstdData::from_var_bin_view_without_dict(vbv, level, values_per_frame, ctx)?,
278 validity,
279 )
280 }
281
282 pub fn from_primitive(
284 parray: &PrimitiveArray,
285 level: i32,
286 values_per_frame: usize,
287 ctx: &mut ExecutionCtx,
288 ) -> VortexResult<ZstdArray> {
289 let validity = parray.validity()?;
290 Self::try_new(
291 parray.dtype().clone(),
292 ZstdData::from_primitive(parray, level, values_per_frame, ctx)?,
293 validity,
294 )
295 }
296
297 pub fn from_var_bin_view(
299 vbv: &VarBinViewArray,
300 level: i32,
301 values_per_frame: usize,
302 ctx: &mut ExecutionCtx,
303 ) -> VortexResult<ZstdArray> {
304 let validity = vbv.validity()?;
305 Self::try_new(
306 vbv.dtype().clone(),
307 ZstdData::from_var_bin_view(vbv, level, values_per_frame, ctx)?,
308 validity,
309 )
310 }
311
312 pub fn decompress(array: &ZstdArray, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
313 let unsliced_validity = child_to_validity(
314 array.as_ref().slots()[0].as_ref(),
315 array.dtype().nullability(),
316 );
317 array
318 .data()
319 .decompress(array.dtype(), &unsliced_validity, ctx)
320 }
321}
322
/// Number of child slots held by a Zstd array.
pub(super) const NUM_SLOTS: usize = 1;
/// Names of the child slots. Slot 0 holds the validity of all unsliced rows.
pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
326
/// Compressed payload of a Zstd array: zstd frames over the valid (non-null) values, plus a
/// logical slice window over the rows they cover.
#[derive(Clone, Debug)]
pub struct ZstdData {
    /// Optional zstd dictionary shared by every frame; its length is recorded in
    /// `metadata.dictionary_size`.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// Compressed zstd frames, in value order.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Dictionary size and per-frame uncompressed size / value count.
    pub(crate) metadata: ZstdMetadata,
    /// Number of rows (nulls included) covered by all frames before any slicing; the
    /// validity must have exactly this length.
    unsliced_n_rows: usize,
    /// First row of the logical slice (inclusive), in unsliced row coordinates.
    slice_start: usize,
    /// End of the logical slice (exclusive), in unsliced row coordinates.
    slice_stop: usize,
}
336
337impl Display for ZstdData {
338 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
339 write!(
340 f,
341 "nrows: {}, slice: {}..{}",
342 self.unsliced_n_rows, self.slice_start, self.slice_stop
343 )
344 }
345}
346
/// Owned, decomposed form of [`ZstdData`], as returned by `ZstdData::into_parts`.
pub struct ZstdDataParts {
    /// Optional shared zstd dictionary.
    pub dictionary: Option<ByteBuffer>,
    /// Compressed zstd frames, in value order.
    pub frames: Vec<ByteBuffer>,
    /// Dictionary size and per-frame metadata.
    pub metadata: ZstdMetadata,
    /// Validity supplied by the caller of `into_parts`.
    pub validity: Validity,
    /// Unsliced row count covered by the frames.
    pub n_rows: usize,
    /// Start of the logical slice (inclusive).
    pub slice_start: usize,
    /// End of the logical slice (exclusive).
    pub slice_stop: usize,
}
356
/// Output of `ZstdData::compress_values`: the optional trained dictionary, the compressed
/// frames, and one metadata entry per frame.
#[derive(Debug)]
struct Frames {
    dictionary: Option<ByteBuffer>,
    frames: Vec<ByteBuffer>,
    frame_metas: Vec<ZstdFrameMetadata>,
}
364
/// Picks the maximum zstd dictionary size for `uncompressed_size` bytes of input.
///
/// The result is 1% of the input, kept within 256 bytes and 100 KiB.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const MIN_DICT_SIZE: usize = 256;
    const MAX_DICT_SIZE: usize = 100 * 1024;

    let one_percent = uncompressed_size / 100;
    one_percent.clamp(MIN_DICT_SIZE, MAX_DICT_SIZE)
}
372
373fn collect_valid_primitive(
374 parray: &PrimitiveArray,
375 ctx: &mut ExecutionCtx,
376) -> VortexResult<PrimitiveArray> {
377 let mask = parray
378 .as_ref()
379 .validity()?
380 .execute_mask(parray.as_ref().len(), ctx)?;
381 let result = parray.filter(mask)?.execute::<PrimitiveArray>(ctx)?;
382 Ok(result)
383}
384
385fn collect_valid_vbv(
386 vbv: &VarBinViewArray,
387 ctx: &mut ExecutionCtx,
388) -> VortexResult<(ByteBuffer, Vec<usize>)> {
389 let mask = vbv
390 .as_ref()
391 .validity()?
392 .execute_mask(vbv.as_ref().len(), ctx)?;
393 let buffer_and_value_byte_indices = match mask.bit_buffer() {
394 AllOr::None => (Buffer::empty(), Vec::new()),
395 _ => {
396 let mut buffer = BufferMut::with_capacity(
397 usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
398 + mask.true_count() * size_of::<ViewLen>(),
399 );
400 let mut value_byte_indices = Vec::new();
401 vbv.with_iterator(|iterator| {
402 for value in iterator.flatten() {
404 value_byte_indices.push(buffer.len());
405 buffer
407 .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
408 buffer.extend_from_slice(value);
409 }
410 Ok::<_, VortexError>(())
411 })?;
412 (buffer.freeze(), value_byte_indices)
413 }
414 };
415 Ok(buffer_and_value_byte_indices)
416}
417
418pub fn reconstruct_views(
427 buffer: &ByteBuffer,
428 max_buffer_len: usize,
429) -> (Vec<ByteBuffer>, Buffer<BinaryView>) {
430 let mut views = BufferMut::<BinaryView>::empty();
431 let mut buffers = Vec::new();
432 let mut segment_start: usize = 0;
433 let mut offset = 0;
434
435 while offset < buffer.len() {
436 let str_len = ViewLen::from_le_bytes(
437 buffer
438 .get(offset..offset + size_of::<ViewLen>())
439 .vortex_expect("corrupted zstd length")
440 .try_into()
441 .ok()
442 .vortex_expect("must fit ViewLen size"),
443 ) as usize;
444
445 let value_data_offset = offset + size_of::<ViewLen>();
446 let local_offset = value_data_offset - segment_start;
447
448 if local_offset + str_len > max_buffer_len && offset > segment_start {
449 buffers.push(buffer.slice(segment_start..offset));
450 segment_start = offset;
451 }
452
453 let local_offset = u32::try_from(value_data_offset - segment_start)
454 .vortex_expect("local offset within segment must fit in u32");
455 let buf_index = u32::try_from(buffers.len()).vortex_expect("buffer index must fit in u32");
456 let value = &buffer[value_data_offset..value_data_offset + str_len];
457 views.push(BinaryView::make_view(value, buf_index, local_offset));
458 offset = value_data_offset + str_len;
459 }
460
461 if segment_start < buffer.len() {
462 buffers.push(buffer.slice(segment_start..buffer.len()));
463 }
464
465 (buffers, views.freeze())
466}
467
468impl ZstdData {
469 pub fn new(
470 dictionary: Option<ByteBuffer>,
471 frames: Vec<ByteBuffer>,
472 metadata: ZstdMetadata,
473 n_rows: usize,
474 ) -> Self {
475 Self {
476 dictionary,
477 frames,
478 metadata,
479 unsliced_n_rows: n_rows,
480 slice_start: 0,
481 slice_stop: n_rows,
482 }
483 }
484
485 pub fn validate(&self, dtype: &DType, len: usize, validity: &Validity) -> VortexResult<()> {
486 vortex_ensure!(
487 matches!(
488 dtype,
489 DType::Primitive(..) | DType::Binary(_) | DType::Utf8(_)
490 ),
491 "Unsupported dtype for Zstd array: {dtype}"
492 );
493 vortex_ensure!(
494 self.slice_start <= self.slice_stop,
495 "Invalid slice range {}..{}",
496 self.slice_start,
497 self.slice_stop
498 );
499 vortex_ensure!(
500 self.slice_stop <= self.unsliced_n_rows,
501 "Slice stop {} exceeds unsliced row count {}",
502 self.slice_stop,
503 self.unsliced_n_rows
504 );
505 vortex_ensure!(
506 self.slice_stop - self.slice_start == len,
507 "Slice length {} does not match array length {}",
508 self.slice_stop - self.slice_start,
509 len
510 );
511 if let Some(validity_len) = validity.maybe_len() {
512 vortex_ensure!(
513 validity_len == self.unsliced_n_rows,
514 "Validity length {} does not match unsliced row count {}",
515 validity_len,
516 self.unsliced_n_rows
517 );
518 }
519
520 match &self.dictionary {
521 Some(dictionary) => vortex_ensure!(
522 usize::try_from(self.metadata.dictionary_size)? == dictionary.len(),
523 "Dictionary size metadata {} does not match buffer size {}",
524 self.metadata.dictionary_size,
525 dictionary.len()
526 ),
527 None => vortex_ensure!(
528 self.metadata.dictionary_size == 0,
529 "Dictionary metadata present without dictionary buffer"
530 ),
531 }
532 vortex_ensure!(
533 self.frames.len() == self.metadata.frames.len(),
534 "Frame count {} does not match metadata frame count {}",
535 self.frames.len(),
536 self.metadata.frames.len()
537 );
538
539 Ok(())
540 }
541
542 pub(crate) fn with_slice(&self, start: usize, stop: usize) -> Self {
543 let new_start = self.slice_start + start;
544 let new_stop = self.slice_start + stop;
545
546 assert!(
547 new_start <= self.slice_stop,
548 "new slice start {new_start} exceeds end {}",
549 self.slice_stop
550 );
551
552 assert!(
553 new_stop <= self.slice_stop,
554 "new slice stop {new_stop} exceeds end {}",
555 self.slice_stop
556 );
557
558 Self {
559 slice_start: new_start,
560 slice_stop: new_stop,
561 ..self.clone()
562 }
563 }
564
565 fn compress_values(
566 value_bytes: &ByteBuffer,
567 frame_byte_starts: &[usize],
568 level: i32,
569 values_per_frame: usize,
570 n_values: usize,
571 use_dictionary: bool,
572 ) -> VortexResult<Frames> {
573 let n_frames = frame_byte_starts.len();
574
575 let mut sample_sizes = Vec::with_capacity(n_frames);
577 for i in 0..n_frames {
578 let frame_byte_end = frame_byte_starts
579 .get(i + 1)
580 .copied()
581 .unwrap_or(value_bytes.len());
582 sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
583 }
584 debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
585
586 let (dictionary, mut compressor) = if !use_dictionary
587 || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY
588 {
589 (None, zstd::bulk::Compressor::new(level)?)
591 } else {
592 let max_dict_size = choose_max_dict_size(value_bytes.len());
594 let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
595 .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
596
597 let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
598 (Some(ByteBuffer::from(dict)), compressor)
599 };
600
601 let mut frame_metas = vec![];
602 let mut frames = vec![];
603 for i in 0..n_frames {
604 let frame_byte_end = frame_byte_starts
605 .get(i + 1)
606 .copied()
607 .unwrap_or(value_bytes.len());
608
609 let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
610 let compressed = compressor
611 .compress(uncompressed)
612 .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
613 frame_metas.push(ZstdFrameMetadata {
614 uncompressed_size: uncompressed.len() as u64,
615 n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
616 });
617 frames.push(ByteBuffer::from(compressed));
618 }
619
620 Ok(Frames {
621 dictionary,
622 frames,
623 frame_metas,
624 })
625 }
626
627 pub fn from_primitive(
634 parray: &PrimitiveArray,
635 level: i32,
636 values_per_frame: usize,
637 ctx: &mut ExecutionCtx,
638 ) -> VortexResult<Self> {
639 Self::from_primitive_impl(parray, level, values_per_frame, true, ctx)
640 }
641
642 pub fn from_primitive_without_dict(
656 parray: &PrimitiveArray,
657 level: i32,
658 values_per_frame: usize,
659 ctx: &mut ExecutionCtx,
660 ) -> VortexResult<Self> {
661 Self::from_primitive_impl(parray, level, values_per_frame, false, ctx)
662 }
663
664 fn from_primitive_impl(
665 parray: &PrimitiveArray,
666 level: i32,
667 values_per_frame: usize,
668 use_dictionary: bool,
669 ctx: &mut ExecutionCtx,
670 ) -> VortexResult<Self> {
671 let byte_width = parray.ptype().byte_width();
672
673 let values = collect_valid_primitive(parray, ctx)?;
675 let n_values = values.len();
676 let values_per_frame = if values_per_frame > 0 {
677 values_per_frame
678 } else {
679 n_values
680 };
681
682 let value_bytes = values.buffer_handle().try_to_host_sync()?;
683 let alignment = *value_bytes.alignment();
685 let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment;
686
687 let frame_byte_starts = (0..n_values * byte_width)
688 .step_by(step_width)
689 .collect::<Vec<_>>();
690 let Frames {
691 dictionary,
692 frames,
693 frame_metas,
694 } = Self::compress_values(
695 &value_bytes,
696 &frame_byte_starts,
697 level,
698 values_per_frame,
699 n_values,
700 use_dictionary,
701 )?;
702
703 let metadata = ZstdMetadata {
704 dictionary_size: dictionary
705 .as_ref()
706 .map_or(0, |dict| dict.len())
707 .try_into()?,
708 frames: frame_metas,
709 };
710
711 Ok(ZstdData::new(dictionary, frames, metadata, parray.len()))
712 }
713
714 pub fn from_var_bin_view(
721 vbv: &VarBinViewArray,
722 level: i32,
723 values_per_frame: usize,
724 ctx: &mut ExecutionCtx,
725 ) -> VortexResult<Self> {
726 Self::from_var_bin_view_impl(vbv, level, values_per_frame, true, ctx)
727 }
728
729 pub fn from_var_bin_view_without_dict(
743 vbv: &VarBinViewArray,
744 level: i32,
745 values_per_frame: usize,
746 ctx: &mut ExecutionCtx,
747 ) -> VortexResult<Self> {
748 Self::from_var_bin_view_impl(vbv, level, values_per_frame, false, ctx)
749 }
750
751 fn from_var_bin_view_impl(
752 vbv: &VarBinViewArray,
753 level: i32,
754 values_per_frame: usize,
755 use_dictionary: bool,
756 ctx: &mut ExecutionCtx,
757 ) -> VortexResult<Self> {
758 let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv, ctx)?;
765 let n_values = value_byte_indices.len();
766 let values_per_frame = if values_per_frame > 0 {
767 values_per_frame
768 } else {
769 n_values
770 };
771
772 let frame_byte_starts = (0..n_values)
773 .step_by(values_per_frame)
774 .map(|i| value_byte_indices[i])
775 .collect::<Vec<_>>();
776 let Frames {
777 dictionary,
778 frames,
779 frame_metas,
780 } = Self::compress_values(
781 &value_bytes,
782 &frame_byte_starts,
783 level,
784 values_per_frame,
785 n_values,
786 use_dictionary,
787 )?;
788
789 let metadata = ZstdMetadata {
790 dictionary_size: dictionary
791 .as_ref()
792 .map_or(0, |dict| dict.len())
793 .try_into()?,
794 frames: frame_metas,
795 };
796 Ok(ZstdData::new(dictionary, frames, metadata, vbv.len()))
797 }
798
799 pub fn from_canonical(
800 canonical: &Canonical,
801 level: i32,
802 values_per_frame: usize,
803 ctx: &mut ExecutionCtx,
804 ) -> VortexResult<Option<Self>> {
805 match canonical {
806 Canonical::Primitive(parray) => Ok(Some(ZstdData::from_primitive(
807 parray,
808 level,
809 values_per_frame,
810 ctx,
811 )?)),
812 Canonical::VarBinView(vbv) => Ok(Some(ZstdData::from_var_bin_view(
813 vbv,
814 level,
815 values_per_frame,
816 ctx,
817 )?)),
818 _ => Ok(None),
819 }
820 }
821
822 pub fn from_array(
823 array: ArrayRef,
824 level: i32,
825 values_per_frame: usize,
826 ctx: &mut ExecutionCtx,
827 ) -> VortexResult<Self> {
828 let canonical = array.execute::<Canonical>(ctx)?;
829 Self::from_canonical(&canonical, level, values_per_frame, ctx)?
830 .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
831 }
832
833 fn byte_width(dtype: &DType) -> usize {
834 if dtype.is_primitive() {
835 dtype.as_ptype().byte_width()
836 } else {
837 1
838 }
839 }
840
841 fn decompress(
842 &self,
843 dtype: &DType,
844 unsliced_validity: &Validity,
845 ctx: &mut ExecutionCtx,
846 ) -> VortexResult<ArrayRef> {
847 let byte_width = Self::byte_width(dtype);
850 let slice_n_rows = self.slice_stop - self.slice_start;
851 let slice_value_indices = unsliced_validity
852 .execute_mask(self.unsliced_n_rows, ctx)?
853 .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
854
855 let slice_value_idx_start = slice_value_indices[0];
856 let slice_value_idx_stop = slice_value_indices[1];
857
858 let mut frames_to_decompress = vec![];
859 let mut value_idx_start = 0;
860 let mut uncompressed_size_to_decompress = 0;
861 let mut n_skipped_values = 0;
862 for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
863 if value_idx_start >= slice_value_idx_stop {
864 break;
865 }
866
867 let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
868 .vortex_expect("Uncompressed size must fit in usize");
869 let frame_n_values = if frame_meta.n_values == 0 {
870 frame_uncompressed_size / byte_width
872 } else {
873 usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
874 };
875
876 let value_idx_stop = value_idx_start + frame_n_values;
877 if value_idx_stop > slice_value_idx_start {
878 frames_to_decompress.push(frame);
880 uncompressed_size_to_decompress += frame_uncompressed_size;
881 } else {
882 n_skipped_values += frame_n_values;
883 }
884 value_idx_start = value_idx_stop;
885 }
886
887 let mut decompressor = if let Some(dictionary) = &self.dictionary {
889 zstd::bulk::Decompressor::with_dictionary(dictionary)?
890 } else {
891 zstd::bulk::Decompressor::new()?
892 };
893 let mut decompressed = ByteBufferMut::with_capacity_aligned(
894 uncompressed_size_to_decompress,
895 Alignment::new(byte_width),
896 );
897 unsafe {
898 decompressed.set_len(uncompressed_size_to_decompress);
901 }
902 let mut uncompressed_start = 0;
903 for frame in frames_to_decompress {
904 let uncompressed_written = decompressor
905 .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])?;
906 uncompressed_start += uncompressed_written;
907 }
908 if uncompressed_start != uncompressed_size_to_decompress {
909 vortex_panic!(
910 "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
911 uncompressed_size_to_decompress,
912 uncompressed_start
913 );
914 }
915
916 let decompressed = decompressed.freeze();
917 let mut slice_validity = unsliced_validity.slice(self.slice_start..self.slice_stop)?;
919
920 if !dtype.is_nullable() && !matches!(slice_validity, Validity::NonNullable) {
929 assert!(
930 matches!(slice_validity, Validity::AllValid),
931 "ZSTD array expects to be non-nullable but there are nulls after decompression"
932 );
933
934 slice_validity = Validity::NonNullable;
935 } else if dtype.is_nullable() && matches!(slice_validity, Validity::NonNullable) {
936 slice_validity = Validity::AllValid;
937 }
938 match dtype {
943 DType::Primitive(..) => {
944 let slice_values_buffer = decompressed.slice(
945 (slice_value_idx_start - n_skipped_values) * byte_width
946 ..(slice_value_idx_stop - n_skipped_values) * byte_width,
947 );
948 let primitive = PrimitiveArray::from_values_byte_buffer(
949 slice_values_buffer,
950 dtype.as_ptype(),
951 slice_validity,
952 slice_n_rows,
953 );
954
955 Ok(primitive.into_array())
956 }
957 DType::Binary(_) | DType::Utf8(_) => {
958 match slice_validity.execute_mask(slice_n_rows, ctx)?.indices() {
959 AllOr::All => {
960 let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
961 let valid_views = all_views.slice(
962 slice_value_idx_start - n_skipped_values
963 ..slice_value_idx_stop - n_skipped_values,
964 );
965
966 Ok(unsafe {
968 VarBinViewArray::new_unchecked(
969 valid_views,
970 Arc::from(buffers),
971 dtype.clone(),
972 slice_validity,
973 )
974 }
975 .into_array())
976 }
977 AllOr::None => Ok(ConstantArray::new(
978 Scalar::null(dtype.clone()),
979 slice_n_rows,
980 )
981 .into_array()),
982 AllOr::Some(valid_indices) => {
983 let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
984 let valid_views = all_views.slice(
985 slice_value_idx_start - n_skipped_values
986 ..slice_value_idx_stop - n_skipped_values,
987 );
988
989 let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
990 for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
991 views[*index] = view
992 }
993
994 Ok(unsafe {
996 VarBinViewArray::new_unchecked(
997 views.freeze(),
998 Arc::from(buffers),
999 dtype.clone(),
1000 slice_validity,
1001 )
1002 }
1003 .into_array())
1004 }
1005 }
1006 }
1007 _ => vortex_panic!("Unsupported dtype for Zstd array: {}", dtype),
1008 }
1009 }
1010
1011 #[inline]
1013 pub fn len(&self) -> usize {
1014 self.slice_stop - self.slice_start
1015 }
1016
1017 #[inline]
1019 pub fn is_empty(&self) -> bool {
1020 self.slice_stop == self.slice_start
1021 }
1022
1023 pub fn into_parts(self, validity: Validity) -> ZstdDataParts {
1024 ZstdDataParts {
1025 dictionary: self.dictionary,
1026 frames: self.frames,
1027 metadata: self.metadata,
1028 validity,
1029 n_rows: self.unsliced_n_rows,
1030 slice_start: self.slice_start,
1031 slice_stop: self.slice_stop,
1032 }
1033 }
1034
1035 pub(crate) fn slice_start(&self) -> usize {
1036 self.slice_start
1037 }
1038
1039 pub(crate) fn slice_stop(&self) -> usize {
1040 self.slice_stop
1041 }
1042
1043 pub(crate) fn unsliced_n_rows(&self) -> usize {
1044 self.unsliced_n_rows
1045 }
1046}
1047
1048impl ValidityVTable<Zstd> for Zstd {
1049 fn validity(array: ArrayView<'_, Zstd>) -> VortexResult<Validity> {
1050 let unsliced_validity =
1051 child_to_validity(array.slots()[0].as_ref(), array.dtype().nullability());
1052 unsliced_validity.slice(array.slice_start()..array.slice_stop())
1053 }
1054}
1055
1056impl OperationsVTable<Zstd> for Zstd {
1057 fn scalar_at(
1058 array: ArrayView<'_, Zstd>,
1059 index: usize,
1060 ctx: &mut ExecutionCtx,
1061 ) -> VortexResult<Scalar> {
1062 let unsliced_validity =
1063 child_to_validity(array.slots()[0].as_ref(), array.dtype().nullability());
1064 let sliced = array.data().with_slice(index, index + 1);
1065 sliced
1066 .decompress(array.dtype(), &unsliced_validity, ctx)?
1067 .execute_scalar(0, ctx)
1068 }
1069}
1070
#[cfg(test)]
#[expect(clippy::cast_possible_truncation)]
mod tests {
    use vortex_buffer::ByteBuffer;

    use super::reconstruct_views;
    use crate::array::BinaryView;

    /// Encodes `strings` as `[len: u32 LE][bytes]` records, the layout `reconstruct_views`
    /// parses.
    fn make_interleaved(strings: &[&[u8]]) -> ByteBuffer {
        let mut buf = Vec::new();
        for s in strings {
            let len = s.len() as u32;
            buf.extend_from_slice(&len.to_le_bytes());
            buf.extend_from_slice(s);
        }
        ByteBuffer::copy_from(buf.as_slice())
    }

    #[test]
    fn test_reconstruct_views_empty() {
        let buf = make_interleaved(&[]);
        let (buffers, views) = reconstruct_views(&buf, 1024);

        assert!(buffers.is_empty());
        assert_eq!(views.len(), 0);
    }

    #[test]
    fn test_reconstruct_views_no_split() {
        let strings: &[&[u8]] = &[b"hello", b"world"];
        let buf = make_interleaved(strings);
        let (buffers, views) = reconstruct_views(&buf, 1024);

        assert_eq!(buffers.len(), 1);
        assert_eq!(views.len(), 2);
        // Offsets skip the 4-byte length prefixes: "hello" at 4, "world" at 4 + 5 + 4 = 13.
        assert_eq!(views[0], BinaryView::make_view(b"hello", 0, 4));
        assert_eq!(views[1], BinaryView::make_view(b"world", 0, 13));
    }

    #[test]
    fn test_reconstruct_views_split_across_segments() {
        // Each record is 4 + 13 = 17 bytes, so two records cannot share a 20-byte segment.
        let strings: &[&[u8]] = &[b"aaaaaaaaaaaaa", b"bbbbbbbbbbbbb"];
        let buf = make_interleaved(strings);
        let (buffers, views) = reconstruct_views(&buf, 20);

        assert_eq!(buffers.len(), 2);
        assert_eq!(buffers[0].as_slice(), &buf.as_slice()[..17]);
        assert_eq!(buffers[1].as_slice(), &buf.as_slice()[17..]);
        assert_eq!(views.len(), 2);
        assert_eq!(views[0], BinaryView::make_view(b"aaaaaaaaaaaaa", 0, 4));
        assert_eq!(views[1], BinaryView::make_view(b"bbbbbbbbbbbbb", 1, 4));
    }

    #[test]
    fn test_reconstruct_views_oversized_value_gets_own_segment() {
        // A single value larger than the limit cannot be split; it stays in one segment.
        let strings: &[&[u8]] = &[b"this value is longer than the limit"];
        let buf = make_interleaved(strings);
        let (buffers, views) = reconstruct_views(&buf, 8);

        assert_eq!(buffers.len(), 1);
        assert_eq!(buffers[0].as_slice(), buf.as_slice());
        assert_eq!(views.len(), 1);
        assert_eq!(views[0], BinaryView::make_view(strings[0], 0, 4));
    }
}