1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use itertools::Itertools as _;
12use prost::Message as _;
13use vortex_array::Array;
14use vortex_array::ArrayEq;
15use vortex_array::ArrayHash;
16use vortex_array::ArrayId;
17use vortex_array::ArrayParts;
18use vortex_array::ArrayRef;
19use vortex_array::ArrayView;
20use vortex_array::Canonical;
21use vortex_array::ExecutionCtx;
22use vortex_array::ExecutionResult;
23use vortex_array::IntoArray;
24use vortex_array::Precision;
25use vortex_array::accessor::ArrayAccessor;
26use vortex_array::arrays::ConstantArray;
27use vortex_array::arrays::PrimitiveArray;
28use vortex_array::arrays::VarBinViewArray;
29use vortex_array::arrays::varbinview::build_views::BinaryView;
30use vortex_array::arrays::varbinview::build_views::MAX_BUFFER_LEN;
31use vortex_array::buffer::BufferHandle;
32use vortex_array::dtype::DType;
33use vortex_array::scalar::Scalar;
34use vortex_array::serde::ArrayChildren;
35use vortex_array::validity::Validity;
36use vortex_array::vtable::OperationsVTable;
37use vortex_array::vtable::VTable;
38use vortex_array::vtable::ValidityVTable;
39use vortex_array::vtable::child_to_validity;
40use vortex_array::vtable::validity_to_child;
41use vortex_buffer::Alignment;
42use vortex_buffer::Buffer;
43use vortex_buffer::BufferMut;
44use vortex_buffer::ByteBuffer;
45use vortex_buffer::ByteBufferMut;
46use vortex_error::VortexError;
47use vortex_error::VortexExpect;
48use vortex_error::VortexResult;
49use vortex_error::vortex_bail;
50use vortex_error::vortex_ensure;
51use vortex_error::vortex_err;
52use vortex_error::vortex_panic;
53use vortex_mask::AllOr;
54use vortex_session::VortexSession;
55use vortex_session::registry::CachedId;
56
57use crate::ZstdFrameMetadata;
58use crate::ZstdMetadata;
59
/// Minimum number of frames (each frame is one dictionary-training sample) required before a
/// zstd dictionary is trained; with fewer frames, values are compressed without a dictionary.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
/// Integer type of the little-endian length prefix written before each binary/utf8 value in
/// the uncompressed frame layout.
type ViewLen = u32;

/// An array encoded with the [`Zstd`] encoding.
pub type ZstdArray = Array<Zstd>;
84
85impl ArrayHash for ZstdData {
86 fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
87 match &self.dictionary {
88 Some(dict) => {
89 true.hash(state);
90 dict.array_hash(state, precision);
91 }
92 None => {
93 false.hash(state);
94 }
95 }
96 for frame in &self.frames {
97 frame.array_hash(state, precision);
98 }
99 self.unsliced_n_rows.hash(state);
100 self.slice_start.hash(state);
101 self.slice_stop.hash(state);
102 }
103}
104
105impl ArrayEq for ZstdData {
106 fn array_eq(&self, other: &Self, precision: Precision) -> bool {
107 if !match (&self.dictionary, &other.dictionary) {
108 (Some(d1), Some(d2)) => d1.array_eq(d2, precision),
109 (None, None) => true,
110 _ => false,
111 } {
112 return false;
113 }
114 if self.frames.len() != other.frames.len() {
115 return false;
116 }
117 for (a, b) in self.frames.iter().zip(&other.frames) {
118 if !a.array_eq(b, precision) {
119 return false;
120 }
121 }
122 self.unsliced_n_rows == other.unsliced_n_rows
123 && self.slice_start == other.slice_start
124 && self.slice_stop == other.slice_stop
125 }
126}
127
128impl VTable for Zstd {
129 type ArrayData = ZstdData;
130
131 type OperationsVTable = Self;
132 type ValidityVTable = Self;
133
134 fn id(&self) -> ArrayId {
135 static ID: CachedId = CachedId::new("vortex.zstd");
136 *ID
137 }
138
139 fn validate(
140 &self,
141 data: &Self::ArrayData,
142 dtype: &DType,
143 len: usize,
144 slots: &[Option<ArrayRef>],
145 ) -> VortexResult<()> {
146 let validity = child_to_validity(slots[0].as_ref(), dtype.nullability());
147 data.validate(dtype, len, &validity)
148 }
149
150 fn nbuffers(array: ArrayView<'_, Self>) -> usize {
151 array.dictionary.is_some() as usize + array.frames.len()
152 }
153
154 fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
155 if let Some(dict) = &array.dictionary {
156 if idx == 0 {
157 return BufferHandle::new_host(dict.clone());
158 }
159 BufferHandle::new_host(array.frames[idx - 1].clone())
160 } else {
161 BufferHandle::new_host(array.frames[idx].clone())
162 }
163 }
164
165 fn buffer_name(array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
166 if array.dictionary.is_some() {
167 if idx == 0 {
168 Some("dictionary".to_string())
169 } else {
170 Some(format!("frame_{}", idx - 1))
171 }
172 } else {
173 Some(format!("frame_{idx}"))
174 }
175 }
176
177 fn serialize(
178 array: ArrayView<'_, Self>,
179 _session: &VortexSession,
180 ) -> VortexResult<Option<Vec<u8>>> {
181 Ok(Some(array.metadata.clone().encode_to_vec()))
182 }
183
184 fn deserialize(
185 &self,
186 dtype: &DType,
187 len: usize,
188 metadata: &[u8],
189 buffers: &[BufferHandle],
190 children: &dyn ArrayChildren,
191 _session: &VortexSession,
192 ) -> VortexResult<ArrayParts<Self>> {
193 let metadata = ZstdMetadata::decode(metadata)?;
194 let validity = if children.is_empty() {
195 Validity::from(dtype.nullability())
196 } else if children.len() == 1 {
197 let validity = children.get(0, &Validity::DTYPE, len)?;
198 Validity::Array(validity)
199 } else {
200 vortex_bail!("ZstdArray expected 0 or 1 child, got {}", children.len());
201 };
202
203 let (dictionary_buffer, compressed_buffers) = if metadata.dictionary_size == 0 {
204 (
206 None,
207 buffers
208 .iter()
209 .map(|b| b.clone().try_to_host_sync())
210 .collect::<VortexResult<Vec<_>>>()?,
211 )
212 } else {
213 (
215 Some(buffers[0].clone().try_to_host_sync()?),
216 buffers[1..]
217 .iter()
218 .map(|b| b.clone().try_to_host_sync())
219 .collect::<VortexResult<Vec<_>>>()?,
220 )
221 };
222
223 let slots = vec![validity_to_child(&validity, len)];
224 let data = ZstdData::new(dictionary_buffer, compressed_buffers, metadata, len);
225 Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
226 }
227
228 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
229 SLOT_NAMES[idx].to_string()
230 }
231
232 fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
233 let unsliced_validity = child_to_validity(
234 array.as_ref().slots()[0].as_ref(),
235 array.dtype().nullability(),
236 );
237 array
238 .data()
239 .decompress(array.dtype(), &unsliced_validity, ctx)?
240 .execute::<ArrayRef>(ctx)
241 .map(ExecutionResult::done)
242 }
243
244 fn reduce_parent(
245 array: ArrayView<'_, Self>,
246 parent: &ArrayRef,
247 child_idx: usize,
248 ) -> VortexResult<Option<ArrayRef>> {
249 crate::rules::RULES.evaluate(array, parent, child_idx)
250 }
251}
252
/// Encoding marker type for zstd-compressed primitive, binary and utf8 arrays
/// (array id `vortex.zstd`).
#[derive(Clone, Debug)]
pub struct Zstd;
255
256impl Zstd {
257 pub fn try_new(dtype: DType, data: ZstdData, validity: Validity) -> VortexResult<ZstdArray> {
258 let len = data.len();
259 data.validate(&dtype, len, &validity)?;
260 let slots = vec![validity_to_child(&validity, data.unsliced_n_rows())];
261 Ok(unsafe {
262 Array::from_parts_unchecked(ArrayParts::new(Zstd, dtype, len, data).with_slots(slots))
263 })
264 }
265
266 pub fn from_var_bin_view_without_dict(
268 vbv: &VarBinViewArray,
269 level: i32,
270 values_per_frame: usize,
271 ctx: &mut ExecutionCtx,
272 ) -> VortexResult<ZstdArray> {
273 let validity = vbv.validity()?;
274 Self::try_new(
275 vbv.dtype().clone(),
276 ZstdData::from_var_bin_view_without_dict(vbv, level, values_per_frame, ctx)?,
277 validity,
278 )
279 }
280
281 pub fn from_primitive(
283 parray: &PrimitiveArray,
284 level: i32,
285 values_per_frame: usize,
286 ctx: &mut ExecutionCtx,
287 ) -> VortexResult<ZstdArray> {
288 let validity = parray.validity()?;
289 Self::try_new(
290 parray.dtype().clone(),
291 ZstdData::from_primitive(parray, level, values_per_frame, ctx)?,
292 validity,
293 )
294 }
295
296 pub fn from_var_bin_view(
298 vbv: &VarBinViewArray,
299 level: i32,
300 values_per_frame: usize,
301 ctx: &mut ExecutionCtx,
302 ) -> VortexResult<ZstdArray> {
303 let validity = vbv.validity()?;
304 Self::try_new(
305 vbv.dtype().clone(),
306 ZstdData::from_var_bin_view(vbv, level, values_per_frame, ctx)?,
307 validity,
308 )
309 }
310
311 pub fn decompress(array: &ZstdArray, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
312 let unsliced_validity = child_to_validity(
313 array.as_ref().slots()[0].as_ref(),
314 array.dtype().nullability(),
315 );
316 array
317 .data()
318 .decompress(array.dtype(), &unsliced_validity, ctx)
319 }
320}
321
/// Number of child slots on a Zstd array.
pub(super) const NUM_SLOTS: usize = 1;
/// Slot names, indexed by slot position; slot 0 holds the (unsliced) validity.
pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
325
/// Compressed payload of a [`Zstd`] array.
///
/// Holds the zstd frames, an optional shared dictionary, and a row slice over the unsliced
/// data. Only non-null values are stored in the frames; the validity of the unsliced rows is
/// kept in the array's validity slot.
#[derive(Clone, Debug)]
pub struct ZstdData {
    /// Trained zstd dictionary used by every frame, if one was trained.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// Separately compressed zstd frames, in value order.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Dictionary size and per-frame metadata (uncompressed byte size, value count).
    pub(crate) metadata: ZstdMetadata,
    /// Number of rows (including nulls) before any slicing.
    unsliced_n_rows: usize,
    /// First visible row (inclusive), as an index into the unsliced rows.
    slice_start: usize,
    /// End of the visible rows (exclusive), as an index into the unsliced rows.
    slice_stop: usize,
}
335
336impl Display for ZstdData {
337 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
338 write!(
339 f,
340 "nrows: {}, slice: {}..{}",
341 self.unsliced_n_rows, self.slice_start, self.slice_stop
342 )
343 }
344}
345
/// Owned components of a [`ZstdData`], as produced by `ZstdData::into_parts`.
pub struct ZstdDataParts {
    /// Trained zstd dictionary shared by all frames, if any.
    pub dictionary: Option<ByteBuffer>,
    /// Compressed zstd frames, in value order.
    pub frames: Vec<ByteBuffer>,
    /// Dictionary size and per-frame metadata.
    pub metadata: ZstdMetadata,
    /// Validity passed in by the caller of `into_parts` (presumably the unsliced validity —
    /// confirm at call sites).
    pub validity: Validity,
    /// Number of rows before slicing.
    pub n_rows: usize,
    /// Start of the visible row slice (inclusive).
    pub slice_start: usize,
    /// End of the visible row slice (exclusive).
    pub slice_stop: usize,
}
355
/// Output of compressing value bytes: the optional trained dictionary, the compressed frames,
/// and one metadata entry per frame (parallel to `frames`).
#[derive(Debug)]
struct Frames {
    dictionary: Option<ByteBuffer>,
    frames: Vec<ByteBuffer>,
    frame_metas: Vec<ZstdFrameMetadata>,
}
363
/// Picks the maximum size of a trained zstd dictionary: 1% of the uncompressed input size,
/// bounded to the range 256 bytes ..= 100 KiB.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const MIN_DICT_SIZE: usize = 256;
    const MAX_DICT_SIZE: usize = 100 * 1024;

    let one_percent = uncompressed_size / 100;
    usize::clamp(one_percent, MIN_DICT_SIZE, MAX_DICT_SIZE)
}
371
372fn collect_valid_primitive(
373 parray: &PrimitiveArray,
374 ctx: &mut ExecutionCtx,
375) -> VortexResult<PrimitiveArray> {
376 let mask = parray
377 .as_ref()
378 .validity()?
379 .execute_mask(parray.as_ref().len(), ctx)?;
380 let result = parray.filter(mask)?.execute::<PrimitiveArray>(ctx)?;
381 Ok(result)
382}
383
384fn collect_valid_vbv(
385 vbv: &VarBinViewArray,
386 ctx: &mut ExecutionCtx,
387) -> VortexResult<(ByteBuffer, Vec<usize>)> {
388 let mask = vbv
389 .as_ref()
390 .validity()?
391 .execute_mask(vbv.as_ref().len(), ctx)?;
392 let buffer_and_value_byte_indices = match mask.bit_buffer() {
393 AllOr::None => (Buffer::empty(), Vec::new()),
394 _ => {
395 let mut buffer = BufferMut::with_capacity(
396 usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
397 + mask.true_count() * size_of::<ViewLen>(),
398 );
399 let mut value_byte_indices = Vec::new();
400 vbv.with_iterator(|iterator| {
401 for value in iterator.flatten() {
403 value_byte_indices.push(buffer.len());
404 buffer
406 .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
407 buffer.extend_from_slice(value);
408 }
409 Ok::<_, VortexError>(())
410 })?;
411 (buffer.freeze(), value_byte_indices)
412 }
413 };
414 Ok(buffer_and_value_byte_indices)
415}
416
417pub fn reconstruct_views(
426 buffer: &ByteBuffer,
427 max_buffer_len: usize,
428) -> (Vec<ByteBuffer>, Buffer<BinaryView>) {
429 let mut views = BufferMut::<BinaryView>::empty();
430 let mut buffers = Vec::new();
431 let mut segment_start: usize = 0;
432 let mut offset = 0;
433
434 while offset < buffer.len() {
435 let str_len = ViewLen::from_le_bytes(
436 buffer
437 .get(offset..offset + size_of::<ViewLen>())
438 .vortex_expect("corrupted zstd length")
439 .try_into()
440 .ok()
441 .vortex_expect("must fit ViewLen size"),
442 ) as usize;
443
444 let value_data_offset = offset + size_of::<ViewLen>();
445 let local_offset = value_data_offset - segment_start;
446
447 if local_offset + str_len > max_buffer_len && offset > segment_start {
448 buffers.push(buffer.slice(segment_start..offset));
449 segment_start = offset;
450 }
451
452 let local_offset = u32::try_from(value_data_offset - segment_start)
453 .vortex_expect("local offset within segment must fit in u32");
454 let buf_index = u32::try_from(buffers.len()).vortex_expect("buffer index must fit in u32");
455 let value = &buffer[value_data_offset..value_data_offset + str_len];
456 views.push(BinaryView::make_view(value, buf_index, local_offset));
457 offset = value_data_offset + str_len;
458 }
459
460 if segment_start < buffer.len() {
461 buffers.push(buffer.slice(segment_start..buffer.len()));
462 }
463
464 (buffers, views.freeze())
465}
466
467impl ZstdData {
468 pub fn new(
469 dictionary: Option<ByteBuffer>,
470 frames: Vec<ByteBuffer>,
471 metadata: ZstdMetadata,
472 n_rows: usize,
473 ) -> Self {
474 Self {
475 dictionary,
476 frames,
477 metadata,
478 unsliced_n_rows: n_rows,
479 slice_start: 0,
480 slice_stop: n_rows,
481 }
482 }
483
484 pub fn validate(&self, dtype: &DType, len: usize, validity: &Validity) -> VortexResult<()> {
485 vortex_ensure!(
486 matches!(
487 dtype,
488 DType::Primitive(..) | DType::Binary(_) | DType::Utf8(_)
489 ),
490 "Unsupported dtype for Zstd array: {dtype}"
491 );
492 vortex_ensure!(
493 self.slice_start <= self.slice_stop,
494 "Invalid slice range {}..{}",
495 self.slice_start,
496 self.slice_stop
497 );
498 vortex_ensure!(
499 self.slice_stop <= self.unsliced_n_rows,
500 "Slice stop {} exceeds unsliced row count {}",
501 self.slice_stop,
502 self.unsliced_n_rows
503 );
504 vortex_ensure!(
505 self.slice_stop - self.slice_start == len,
506 "Slice length {} does not match array length {}",
507 self.slice_stop - self.slice_start,
508 len
509 );
510 if let Some(validity_len) = validity.maybe_len() {
511 vortex_ensure!(
512 validity_len == self.unsliced_n_rows,
513 "Validity length {} does not match unsliced row count {}",
514 validity_len,
515 self.unsliced_n_rows
516 );
517 }
518
519 match &self.dictionary {
520 Some(dictionary) => vortex_ensure!(
521 usize::try_from(self.metadata.dictionary_size)? == dictionary.len(),
522 "Dictionary size metadata {} does not match buffer size {}",
523 self.metadata.dictionary_size,
524 dictionary.len()
525 ),
526 None => vortex_ensure!(
527 self.metadata.dictionary_size == 0,
528 "Dictionary metadata present without dictionary buffer"
529 ),
530 }
531 vortex_ensure!(
532 self.frames.len() == self.metadata.frames.len(),
533 "Frame count {} does not match metadata frame count {}",
534 self.frames.len(),
535 self.metadata.frames.len()
536 );
537
538 Ok(())
539 }
540
541 pub(crate) fn with_slice(&self, start: usize, stop: usize) -> Self {
542 let new_start = self.slice_start + start;
543 let new_stop = self.slice_start + stop;
544
545 assert!(
546 new_start <= self.slice_stop,
547 "new slice start {new_start} exceeds end {}",
548 self.slice_stop
549 );
550
551 assert!(
552 new_stop <= self.slice_stop,
553 "new slice stop {new_stop} exceeds end {}",
554 self.slice_stop
555 );
556
557 Self {
558 slice_start: new_start,
559 slice_stop: new_stop,
560 ..self.clone()
561 }
562 }
563
564 fn compress_values(
565 value_bytes: &ByteBuffer,
566 frame_byte_starts: &[usize],
567 level: i32,
568 values_per_frame: usize,
569 n_values: usize,
570 use_dictionary: bool,
571 ) -> VortexResult<Frames> {
572 let n_frames = frame_byte_starts.len();
573
574 let mut sample_sizes = Vec::with_capacity(n_frames);
576 for i in 0..n_frames {
577 let frame_byte_end = frame_byte_starts
578 .get(i + 1)
579 .copied()
580 .unwrap_or(value_bytes.len());
581 sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
582 }
583 debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
584
585 let (dictionary, mut compressor) = if !use_dictionary
586 || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY
587 {
588 (None, zstd::bulk::Compressor::new(level)?)
590 } else {
591 let max_dict_size = choose_max_dict_size(value_bytes.len());
593 let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
594 .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
595
596 let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
597 (Some(ByteBuffer::from(dict)), compressor)
598 };
599
600 let mut frame_metas = vec![];
601 let mut frames = vec![];
602 for i in 0..n_frames {
603 let frame_byte_end = frame_byte_starts
604 .get(i + 1)
605 .copied()
606 .unwrap_or(value_bytes.len());
607
608 let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
609 let compressed = compressor
610 .compress(uncompressed)
611 .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
612 frame_metas.push(ZstdFrameMetadata {
613 uncompressed_size: uncompressed.len() as u64,
614 n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
615 });
616 frames.push(ByteBuffer::from(compressed));
617 }
618
619 Ok(Frames {
620 dictionary,
621 frames,
622 frame_metas,
623 })
624 }
625
626 pub fn from_primitive(
633 parray: &PrimitiveArray,
634 level: i32,
635 values_per_frame: usize,
636 ctx: &mut ExecutionCtx,
637 ) -> VortexResult<Self> {
638 Self::from_primitive_impl(parray, level, values_per_frame, true, ctx)
639 }
640
641 pub fn from_primitive_without_dict(
655 parray: &PrimitiveArray,
656 level: i32,
657 values_per_frame: usize,
658 ctx: &mut ExecutionCtx,
659 ) -> VortexResult<Self> {
660 Self::from_primitive_impl(parray, level, values_per_frame, false, ctx)
661 }
662
663 fn from_primitive_impl(
664 parray: &PrimitiveArray,
665 level: i32,
666 values_per_frame: usize,
667 use_dictionary: bool,
668 ctx: &mut ExecutionCtx,
669 ) -> VortexResult<Self> {
670 let byte_width = parray.ptype().byte_width();
671
672 let values = collect_valid_primitive(parray, ctx)?;
674 let n_values = values.len();
675 let values_per_frame = if values_per_frame > 0 {
676 values_per_frame
677 } else {
678 n_values
679 };
680
681 let value_bytes = values.buffer_handle().try_to_host_sync()?;
682 let alignment = *value_bytes.alignment();
684 let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment;
685
686 let frame_byte_starts = (0..n_values * byte_width)
687 .step_by(step_width)
688 .collect::<Vec<_>>();
689 let Frames {
690 dictionary,
691 frames,
692 frame_metas,
693 } = Self::compress_values(
694 &value_bytes,
695 &frame_byte_starts,
696 level,
697 values_per_frame,
698 n_values,
699 use_dictionary,
700 )?;
701
702 let metadata = ZstdMetadata {
703 dictionary_size: dictionary
704 .as_ref()
705 .map_or(0, |dict| dict.len())
706 .try_into()?,
707 frames: frame_metas,
708 };
709
710 Ok(ZstdData::new(dictionary, frames, metadata, parray.len()))
711 }
712
713 pub fn from_var_bin_view(
720 vbv: &VarBinViewArray,
721 level: i32,
722 values_per_frame: usize,
723 ctx: &mut ExecutionCtx,
724 ) -> VortexResult<Self> {
725 Self::from_var_bin_view_impl(vbv, level, values_per_frame, true, ctx)
726 }
727
728 pub fn from_var_bin_view_without_dict(
742 vbv: &VarBinViewArray,
743 level: i32,
744 values_per_frame: usize,
745 ctx: &mut ExecutionCtx,
746 ) -> VortexResult<Self> {
747 Self::from_var_bin_view_impl(vbv, level, values_per_frame, false, ctx)
748 }
749
750 fn from_var_bin_view_impl(
751 vbv: &VarBinViewArray,
752 level: i32,
753 values_per_frame: usize,
754 use_dictionary: bool,
755 ctx: &mut ExecutionCtx,
756 ) -> VortexResult<Self> {
757 let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv, ctx)?;
764 let n_values = value_byte_indices.len();
765 let values_per_frame = if values_per_frame > 0 {
766 values_per_frame
767 } else {
768 n_values
769 };
770
771 let frame_byte_starts = (0..n_values)
772 .step_by(values_per_frame)
773 .map(|i| value_byte_indices[i])
774 .collect::<Vec<_>>();
775 let Frames {
776 dictionary,
777 frames,
778 frame_metas,
779 } = Self::compress_values(
780 &value_bytes,
781 &frame_byte_starts,
782 level,
783 values_per_frame,
784 n_values,
785 use_dictionary,
786 )?;
787
788 let metadata = ZstdMetadata {
789 dictionary_size: dictionary
790 .as_ref()
791 .map_or(0, |dict| dict.len())
792 .try_into()?,
793 frames: frame_metas,
794 };
795 Ok(ZstdData::new(dictionary, frames, metadata, vbv.len()))
796 }
797
798 pub fn from_canonical(
799 canonical: &Canonical,
800 level: i32,
801 values_per_frame: usize,
802 ctx: &mut ExecutionCtx,
803 ) -> VortexResult<Option<Self>> {
804 match canonical {
805 Canonical::Primitive(parray) => Ok(Some(ZstdData::from_primitive(
806 parray,
807 level,
808 values_per_frame,
809 ctx,
810 )?)),
811 Canonical::VarBinView(vbv) => Ok(Some(ZstdData::from_var_bin_view(
812 vbv,
813 level,
814 values_per_frame,
815 ctx,
816 )?)),
817 _ => Ok(None),
818 }
819 }
820
821 pub fn from_array(
822 array: ArrayRef,
823 level: i32,
824 values_per_frame: usize,
825 ctx: &mut ExecutionCtx,
826 ) -> VortexResult<Self> {
827 let canonical = array.execute::<Canonical>(ctx)?;
828 Self::from_canonical(&canonical, level, values_per_frame, ctx)?
829 .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
830 }
831
832 fn byte_width(dtype: &DType) -> usize {
833 if dtype.is_primitive() {
834 dtype.as_ptype().byte_width()
835 } else {
836 1
837 }
838 }
839
840 fn decompress(
841 &self,
842 dtype: &DType,
843 unsliced_validity: &Validity,
844 ctx: &mut ExecutionCtx,
845 ) -> VortexResult<ArrayRef> {
846 let byte_width = Self::byte_width(dtype);
849 let slice_n_rows = self.slice_stop - self.slice_start;
850 let slice_value_indices = unsliced_validity
851 .execute_mask(self.unsliced_n_rows, ctx)?
852 .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
853
854 let slice_value_idx_start = slice_value_indices[0];
855 let slice_value_idx_stop = slice_value_indices[1];
856
857 let mut frames_to_decompress = vec![];
858 let mut value_idx_start = 0;
859 let mut uncompressed_size_to_decompress = 0;
860 let mut n_skipped_values = 0;
861 for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
862 if value_idx_start >= slice_value_idx_stop {
863 break;
864 }
865
866 let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
867 .vortex_expect("Uncompressed size must fit in usize");
868 let frame_n_values = if frame_meta.n_values == 0 {
869 frame_uncompressed_size / byte_width
871 } else {
872 usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
873 };
874
875 let value_idx_stop = value_idx_start + frame_n_values;
876 if value_idx_stop > slice_value_idx_start {
877 frames_to_decompress.push(frame);
879 uncompressed_size_to_decompress += frame_uncompressed_size;
880 } else {
881 n_skipped_values += frame_n_values;
882 }
883 value_idx_start = value_idx_stop;
884 }
885
886 let mut decompressor = if let Some(dictionary) = &self.dictionary {
888 zstd::bulk::Decompressor::with_dictionary(dictionary)?
889 } else {
890 zstd::bulk::Decompressor::new()?
891 };
892 let mut decompressed = ByteBufferMut::with_capacity_aligned(
893 uncompressed_size_to_decompress,
894 Alignment::new(byte_width),
895 );
896 unsafe {
897 decompressed.set_len(uncompressed_size_to_decompress);
900 }
901 let mut uncompressed_start = 0;
902 for frame in frames_to_decompress {
903 let uncompressed_written = decompressor
904 .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])?;
905 uncompressed_start += uncompressed_written;
906 }
907 if uncompressed_start != uncompressed_size_to_decompress {
908 vortex_panic!(
909 "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
910 uncompressed_size_to_decompress,
911 uncompressed_start
912 );
913 }
914
915 let decompressed = decompressed.freeze();
916 let mut slice_validity = unsliced_validity.slice(self.slice_start..self.slice_stop)?;
918
919 if !dtype.is_nullable() && !matches!(slice_validity, Validity::NonNullable) {
928 assert!(
929 matches!(slice_validity, Validity::AllValid),
930 "ZSTD array expects to be non-nullable but there are nulls after decompression"
931 );
932
933 slice_validity = Validity::NonNullable;
934 } else if dtype.is_nullable() && matches!(slice_validity, Validity::NonNullable) {
935 slice_validity = Validity::AllValid;
936 }
937 match dtype {
942 DType::Primitive(..) => {
943 let slice_values_buffer = decompressed.slice(
944 (slice_value_idx_start - n_skipped_values) * byte_width
945 ..(slice_value_idx_stop - n_skipped_values) * byte_width,
946 );
947 let primitive = PrimitiveArray::from_values_byte_buffer(
948 slice_values_buffer,
949 dtype.as_ptype(),
950 slice_validity,
951 slice_n_rows,
952 );
953
954 Ok(primitive.into_array())
955 }
956 DType::Binary(_) | DType::Utf8(_) => {
957 match slice_validity.execute_mask(slice_n_rows, ctx)?.indices() {
958 AllOr::All => {
959 let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
960 let valid_views = all_views.slice(
961 slice_value_idx_start - n_skipped_values
962 ..slice_value_idx_stop - n_skipped_values,
963 );
964
965 Ok(unsafe {
967 VarBinViewArray::new_unchecked(
968 valid_views,
969 Arc::from(buffers),
970 dtype.clone(),
971 slice_validity,
972 )
973 }
974 .into_array())
975 }
976 AllOr::None => Ok(ConstantArray::new(
977 Scalar::null(dtype.clone()),
978 slice_n_rows,
979 )
980 .into_array()),
981 AllOr::Some(valid_indices) => {
982 let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
983 let valid_views = all_views.slice(
984 slice_value_idx_start - n_skipped_values
985 ..slice_value_idx_stop - n_skipped_values,
986 );
987
988 let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
989 for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
990 views[*index] = view
991 }
992
993 Ok(unsafe {
995 VarBinViewArray::new_unchecked(
996 views.freeze(),
997 Arc::from(buffers),
998 dtype.clone(),
999 slice_validity,
1000 )
1001 }
1002 .into_array())
1003 }
1004 }
1005 }
1006 _ => vortex_panic!("Unsupported dtype for Zstd array: {}", dtype),
1007 }
1008 }
1009
1010 #[inline]
1012 pub fn len(&self) -> usize {
1013 self.slice_stop - self.slice_start
1014 }
1015
1016 #[inline]
1018 pub fn is_empty(&self) -> bool {
1019 self.slice_stop == self.slice_start
1020 }
1021
1022 pub fn into_parts(self, validity: Validity) -> ZstdDataParts {
1023 ZstdDataParts {
1024 dictionary: self.dictionary,
1025 frames: self.frames,
1026 metadata: self.metadata,
1027 validity,
1028 n_rows: self.unsliced_n_rows,
1029 slice_start: self.slice_start,
1030 slice_stop: self.slice_stop,
1031 }
1032 }
1033
1034 pub(crate) fn slice_start(&self) -> usize {
1035 self.slice_start
1036 }
1037
1038 pub(crate) fn slice_stop(&self) -> usize {
1039 self.slice_stop
1040 }
1041
1042 pub(crate) fn unsliced_n_rows(&self) -> usize {
1043 self.unsliced_n_rows
1044 }
1045}
1046
1047impl ValidityVTable<Zstd> for Zstd {
1048 fn validity(array: ArrayView<'_, Zstd>) -> VortexResult<Validity> {
1049 let unsliced_validity =
1050 child_to_validity(array.slots()[0].as_ref(), array.dtype().nullability());
1051 unsliced_validity.slice(array.slice_start()..array.slice_stop())
1052 }
1053}
1054
1055impl OperationsVTable<Zstd> for Zstd {
1056 fn scalar_at(
1057 array: ArrayView<'_, Zstd>,
1058 index: usize,
1059 ctx: &mut ExecutionCtx,
1060 ) -> VortexResult<Scalar> {
1061 let unsliced_validity =
1062 child_to_validity(array.slots()[0].as_ref(), array.dtype().nullability());
1063 let sliced = array.data().with_slice(index, index + 1);
1064 sliced
1065 .decompress(array.dtype(), &unsliced_validity, ctx)?
1066 .execute_scalar(0, ctx)
1067 }
1068}
1069
#[cfg(test)]
#[expect(clippy::cast_possible_truncation)]
mod tests {
    use vortex_buffer::ByteBuffer;

    use super::reconstruct_views;
    use crate::array::BinaryView;

    /// Encodes `strings` in the layout `reconstruct_views` expects: for each value, a
    /// little-endian `u32` length prefix followed by the value bytes.
    fn make_interleaved(strings: &[&[u8]]) -> ByteBuffer {
        let mut buf = Vec::new();
        for s in strings {
            let len = s.len() as u32;
            buf.extend_from_slice(&len.to_le_bytes());
            buf.extend_from_slice(s);
        }
        ByteBuffer::copy_from(buf.as_slice())
    }

    #[test]
    fn test_reconstruct_views_no_split() {
        let strings: &[&[u8]] = &[b"hello", b"world"];
        let buf = make_interleaved(strings);
        let (buffers, views) = reconstruct_views(&buf, 1024);

        assert_eq!(buffers.len(), 1);
        assert_eq!(views.len(), 2);
        // "hello" starts after its 4-byte prefix; "world" after 4 + 5 + 4 = 13 bytes.
        assert_eq!(views[0], BinaryView::make_view(b"hello", 0, 4));
        assert_eq!(views[1], BinaryView::make_view(b"world", 0, 13));
    }

    #[test]
    fn test_reconstruct_views_split_across_segments() {
        // Each entry is 4 + 13 = 17 bytes. The second would end at byte 34, past the 20-byte
        // limit, so it starts a new segment in which its data sits at offset 4.
        let strings: &[&[u8]] = &[b"aaaaaaaaaaaaa", b"bbbbbbbbbbbbb"];
        let buf = make_interleaved(strings);
        let (buffers, views) = reconstruct_views(&buf, 20);

        assert_eq!(buffers.len(), 2);
        assert_eq!(views.len(), 2);
        assert_eq!(views[0], BinaryView::make_view(b"aaaaaaaaaaaaa", 0, 4));
        assert_eq!(views[1], BinaryView::make_view(b"bbbbbbbbbbbbb", 1, 4));
    }

    #[test]
    fn test_reconstruct_views_empty() {
        let (buffers, views) = reconstruct_views(&ByteBuffer::empty(), 16);

        assert!(buffers.is_empty());
        assert!(views.is_empty());
    }

    #[test]
    fn test_reconstruct_views_oversized_value_stays_in_first_segment() {
        // A value longer than the limit must not produce an empty leading segment.
        let value: &[u8] = b"abcdefghijklmnopqrstuvwxyz";
        let buf = make_interleaved(&[value]);
        let (buffers, views) = reconstruct_views(&buf, 8);

        assert_eq!(buffers.len(), 1);
        assert_eq!(views.len(), 1);
        assert_eq!(views[0], BinaryView::make_view(value, 0, 4));
    }
}