1use std::fmt::Debug;
5use std::hash::Hash;
6use std::ops::Range;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_array::accessor::ArrayAccessor;
11use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinViewArray};
12use vortex_array::compute::filter;
13use vortex_array::stats::{ArrayStats, StatsSetRef};
14use vortex_array::validity::Validity;
15use vortex_array::vtable::{
16 ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
17 ValiditySliceHelper, ValidityVTableFromValiditySliceHelper,
18};
19use vortex_array::{
20 ArrayEq, ArrayHash, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, Precision,
21 ToCanonical, vtable,
22};
23use vortex_buffer::{Alignment, Buffer, BufferMut, ByteBuffer, ByteBufferMut};
24use vortex_dtype::DType;
25use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_err, vortex_panic};
26use vortex_mask::AllOr;
27use vortex_scalar::Scalar;
28use vortex_vector::binaryview::BinaryView;
29
30use crate::serde::{ZstdFrameMetadata, ZstdMetadata};
31
/// Minimum number of frames required before training a shared zstd dictionary;
/// with fewer samples frames are compressed without one (see `compress_values`).
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
/// Width of the little-endian length prefix written before each variable-length
/// value in the concatenated uncompressed byte stream.
type ViewLen = u32;

// Generates the `ZstdVTable` type and associated encoding plumbing.
vtable!(Zstd);
55
/// Wires `ZstdArray` / `ZstdEncoding` into the vortex vtable machinery.
impl VTable for ZstdVTable {
    type Array = ZstdArray;
    type Encoding = ZstdEncoding;

    type ArrayVTable = Self;
    type CanonicalVTable = Self;
    type OperationsVTable = Self;
    // Validity is derived from the unsliced validity plus the slice bounds via
    // the `ValiditySliceHelper` impl on `ZstdArray`.
    type ValidityVTable = ValidityVTableFromValiditySliceHelper;
    type VisitorVTable = Self;
    // No specialized compute kernels or operators for this encoding.
    type ComputeVTable = NotSupported;
    type EncodeVTable = Self;
    type SerdeVTable = Self;
    type OperatorVTable = NotSupported;

    fn id(_encoding: &Self::Encoding) -> EncodingId {
        EncodingId::new_ref("vortex.zstd")
    }

    fn encoding(_array: &Self::Array) -> EncodingRef {
        EncodingRef::new_ref(ZstdEncoding.as_ref())
    }
}
78
/// Marker type for the zstd encoding (id `"vortex.zstd"`).
#[derive(Clone, Debug)]
pub struct ZstdEncoding;
81
/// An array whose values are stored as zstd-compressed frames.
///
/// Only non-null values are compressed (nulls are tracked separately by
/// `unsliced_validity`). Slicing is lazy: `slice_start`/`slice_stop` narrow the
/// logical view without touching the compressed frames.
#[derive(Clone, Debug)]
pub struct ZstdArray {
    // Optional shared zstd dictionary used to (de)compress every frame.
    pub(crate) dictionary: Option<ByteBuffer>,
    // Compressed frames, in value order.
    pub(crate) frames: Vec<ByteBuffer>,
    // Per-frame uncompressed sizes / value counts plus the dictionary size.
    pub(crate) metadata: ZstdMetadata,
    dtype: DType,
    // Validity of the full, unsliced array.
    pub(crate) unsliced_validity: Validity,
    // Row count of the full, unsliced array.
    unsliced_n_rows: usize,
    stats_set: ArrayStats,
    // Current slice bounds (row indices into the unsliced array).
    slice_start: usize,
    slice_stop: usize,
}
94
/// Intermediate output of `ZstdArray::compress_values`: the trained dictionary
/// (if any), the compressed frames, and the per-frame metadata.
struct Frames {
    dictionary: Option<ByteBuffer>,
    frames: Vec<ByteBuffer>,
    frame_metas: Vec<ZstdFrameMetadata>,
}
100
/// Pick a size budget for a trained zstd dictionary: 1% of the uncompressed
/// payload, bounded to the range [256 B, 100 KiB].
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const FLOOR: usize = 256;
    const CEILING: usize = 100 * 1024;
    let one_percent = uncompressed_size / 100;
    one_percent.max(FLOOR).min(CEILING)
}
108
109fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
110 let mask = parray.validity_mask();
111 Ok(filter(&parray.to_array(), &mask)?.to_primitive())
112}
113
/// Concatenate the non-null values of `vbv` into one contiguous buffer,
/// prefixing each value with its `ViewLen` little-endian byte length.
///
/// Returns the buffer together with the byte index (within the buffer) at
/// which each value's length prefix starts; those indices later serve as
/// frame split points.
fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
    let mask = vbv.validity_mask();
    let buffer_and_value_byte_indices = match mask.bit_buffer() {
        // Everything is null: nothing to collect.
        AllOr::None => (Buffer::empty(), Vec::new()),
        _ => {
            // Capacity hint: the raw value bytes plus one length prefix per
            // valid value.
            let mut buffer = BufferMut::with_capacity(
                usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
                    + mask.true_count() * size_of::<ViewLen>(),
            );
            let mut value_byte_indices = Vec::new();
            vbv.with_iterator(|iterator| {
                // `flatten` skips the `None` (null) entries.
                for value in iterator.flatten() {
                    // Record where this value's record (prefix + bytes) starts.
                    value_byte_indices.push(buffer.len());
                    buffer
                        .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
                    buffer.extend_from_slice(value);
                }
                Ok::<_, VortexError>(())
            })?;
            (buffer.freeze(), value_byte_indices)
        }
    };
    Ok(buffer_and_value_byte_indices)
}
140
/// Rebuild `BinaryView`s over a decompressed, length-prefixed byte buffer, as
/// produced by `collect_valid_vbv`.
///
/// Each view is constructed with buffer index 0 and the offset just past the
/// value's length prefix, i.e. it points at the raw value bytes within
/// `buffer`.
fn reconstruct_views(buffer: ByteBuffer) -> Buffer<BinaryView> {
    let mut res = BufferMut::<BinaryView>::empty();
    let mut offset = 0;
    while offset < buffer.len() {
        // Read the little-endian `ViewLen` length prefix.
        let str_len = ViewLen::from_le_bytes(
            buffer
                .get(offset..offset + size_of::<ViewLen>())
                .vortex_expect("corrupted zstd length")
                .try_into()
                .vortex_expect("must fit ViewLen size"),
        ) as usize;
        offset += size_of::<ViewLen>();
        let value = &buffer[offset..offset + str_len];
        res.push(BinaryView::make_view(
            value,
            0,
            u32::try_from(offset).vortex_expect("offset must fit in u32"),
        ));
        offset += str_len;
    }
    res.freeze()
}
163
164impl ZstdArray {
165 pub fn new(
166 dictionary: Option<ByteBuffer>,
167 frames: Vec<ByteBuffer>,
168 dtype: DType,
169 metadata: ZstdMetadata,
170 n_rows: usize,
171 validity: Validity,
172 ) -> Self {
173 Self {
174 dictionary,
175 frames,
176 metadata,
177 dtype,
178 unsliced_validity: validity,
179 unsliced_n_rows: n_rows,
180 stats_set: Default::default(),
181 slice_start: 0,
182 slice_stop: n_rows,
183 }
184 }
185
186 fn compress_values(
187 value_bytes: &ByteBuffer,
188 frame_byte_starts: &[usize],
189 level: i32,
190 values_per_frame: usize,
191 n_values: usize,
192 ) -> VortexResult<Frames> {
193 let n_frames = frame_byte_starts.len();
194
195 let mut sample_sizes = Vec::with_capacity(n_frames);
197 for i in 0..n_frames {
198 let frame_byte_end = frame_byte_starts
199 .get(i + 1)
200 .copied()
201 .unwrap_or(value_bytes.len());
202 sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
203 }
204 debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
205
206 let (dictionary, mut compressor) = if sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY {
207 (None, zstd::bulk::Compressor::new(level)?)
209 } else {
210 let max_dict_size = choose_max_dict_size(value_bytes.len());
212 let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
213 .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
214
215 let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
216 (Some(ByteBuffer::from(dict)), compressor)
217 };
218
219 let mut frame_metas = vec![];
220 let mut frames = vec![];
221 for i in 0..n_frames {
222 let frame_byte_end = frame_byte_starts
223 .get(i + 1)
224 .copied()
225 .unwrap_or(value_bytes.len());
226 let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
227 let compressed = compressor
228 .compress(uncompressed)
229 .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
230 frame_metas.push(ZstdFrameMetadata {
231 uncompressed_size: uncompressed.len() as u64,
232 n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
233 });
234 frames.push(ByteBuffer::from(compressed));
235 }
236
237 Ok(Frames {
238 dictionary,
239 frames,
240 frame_metas,
241 })
242 }
243
244 pub fn from_primitive(
245 parray: &PrimitiveArray,
246 level: i32,
247 values_per_frame: usize,
248 ) -> VortexResult<Self> {
249 let dtype = parray.dtype().clone();
250 let byte_width = parray.ptype().byte_width();
251
252 let values = collect_valid_primitive(parray)?;
254 let n_values = values.len();
255 let values_per_frame = if values_per_frame > 0 {
256 values_per_frame
257 } else {
258 n_values
259 };
260
261 let value_bytes = values.byte_buffer();
262 let frame_byte_starts = (0..n_values * byte_width)
263 .step_by(values_per_frame * byte_width)
264 .collect::<Vec<_>>();
265 let Frames {
266 dictionary,
267 frames,
268 frame_metas,
269 } = Self::compress_values(
270 value_bytes,
271 &frame_byte_starts,
272 level,
273 values_per_frame,
274 n_values,
275 )?;
276
277 let metadata = ZstdMetadata {
278 dictionary_size: dictionary
279 .as_ref()
280 .map_or(0, |dict| dict.len())
281 .try_into()?,
282 frames: frame_metas,
283 };
284
285 Ok(ZstdArray::new(
286 dictionary,
287 frames,
288 dtype,
289 metadata,
290 parray.len(),
291 parray.validity().clone(),
292 ))
293 }
294
295 pub fn from_var_bin_view(
296 vbv: &VarBinViewArray,
297 level: i32,
298 values_per_frame: usize,
299 ) -> VortexResult<Self> {
300 let dtype = vbv.dtype().clone();
306
307 let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
309 let n_values = value_byte_indices.len();
310 let values_per_frame = if values_per_frame > 0 {
311 values_per_frame
312 } else {
313 n_values
314 };
315
316 let frame_byte_starts = (0..n_values)
317 .step_by(values_per_frame)
318 .map(|i| value_byte_indices[i])
319 .collect::<Vec<_>>();
320 let Frames {
321 dictionary,
322 frames,
323 frame_metas,
324 } = Self::compress_values(
325 &value_bytes,
326 &frame_byte_starts,
327 level,
328 values_per_frame,
329 n_values,
330 )?;
331
332 let metadata = ZstdMetadata {
333 dictionary_size: dictionary
334 .as_ref()
335 .map_or(0, |dict| dict.len())
336 .try_into()?,
337 frames: frame_metas,
338 };
339 Ok(ZstdArray::new(
340 dictionary,
341 frames,
342 dtype,
343 metadata,
344 vbv.len(),
345 vbv.validity().clone(),
346 ))
347 }
348
349 pub fn from_canonical(
350 canonical: &Canonical,
351 level: i32,
352 values_per_frame: usize,
353 ) -> VortexResult<Option<Self>> {
354 match canonical {
355 Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
356 parray,
357 level,
358 values_per_frame,
359 )?)),
360 Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
361 vbv,
362 level,
363 values_per_frame,
364 )?)),
365 _ => Ok(None),
366 }
367 }
368
369 pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
370 Self::from_canonical(&array.to_canonical(), level, values_per_frame)?
371 .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
372 }
373
374 fn byte_width(&self) -> usize {
375 if self.dtype.is_primitive() {
376 self.dtype.as_ptype().byte_width()
377 } else {
378 1
379 }
380 }
381
382 pub fn decompress(&self) -> ArrayRef {
383 let byte_width = self.byte_width();
386 let slice_n_rows = self.slice_stop - self.slice_start;
387 let slice_value_indices = self
388 .unsliced_validity
389 .to_mask(self.unsliced_n_rows)
390 .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
391
392 let slice_value_idx_start = slice_value_indices[0];
393 let slice_value_idx_stop = slice_value_indices[1];
394
395 let mut frames_to_decompress = vec![];
396 let mut value_idx_start = 0;
397 let mut uncompressed_size_to_decompress = 0;
398 let mut n_skipped_values = 0;
399 for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
400 if value_idx_start >= slice_value_idx_stop {
401 break;
402 }
403
404 let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
405 .vortex_expect("Uncompressed size must fit in usize");
406 let frame_n_values = if frame_meta.n_values == 0 {
407 frame_uncompressed_size / byte_width
409 } else {
410 usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
411 };
412
413 let value_idx_stop = value_idx_start + frame_n_values;
414 if value_idx_stop > slice_value_idx_start {
415 frames_to_decompress.push(frame);
417 uncompressed_size_to_decompress += frame_uncompressed_size;
418 } else {
419 n_skipped_values += frame_n_values;
420 }
421 value_idx_start = value_idx_stop;
422 }
423
424 let mut decompressor = if let Some(dictionary) = &self.dictionary {
426 zstd::bulk::Decompressor::with_dictionary(dictionary)
427 } else {
428 zstd::bulk::Decompressor::new()
429 }
430 .vortex_expect("Decompressor encountered io error");
431 let mut decompressed = ByteBufferMut::with_capacity_aligned(
432 uncompressed_size_to_decompress,
433 Alignment::new(byte_width),
434 );
435 unsafe {
436 decompressed.set_len(uncompressed_size_to_decompress);
439 }
440 let mut uncompressed_start = 0;
441 for frame in frames_to_decompress {
442 let uncompressed_written = decompressor
443 .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])
444 .vortex_expect("error while decompressing zstd array");
445 uncompressed_start += uncompressed_written;
446 }
447 if uncompressed_start != uncompressed_size_to_decompress {
448 vortex_panic!(
449 "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
450 uncompressed_size_to_decompress,
451 uncompressed_start
452 );
453 }
454
455 let decompressed = decompressed.freeze();
456 let slice_validity = self
458 .unsliced_validity
459 .slice(self.slice_start..self.slice_stop);
460
461 match &self.dtype {
462 DType::Primitive(..) => {
463 let slice_values_buffer = decompressed.slice(
464 (slice_value_idx_start - n_skipped_values) * byte_width
465 ..(slice_value_idx_stop - n_skipped_values) * byte_width,
466 );
467 let primitive = PrimitiveArray::from_values_byte_buffer(
468 slice_values_buffer,
469 self.dtype.as_ptype(),
470 slice_validity,
471 slice_n_rows,
472 );
473
474 primitive.into_array()
475 }
476 DType::Binary(_) | DType::Utf8(_) => {
477 match slice_validity.to_mask(slice_n_rows).indices() {
478 AllOr::All => {
479 let valid_views = reconstruct_views(decompressed.clone()).slice(
483 slice_value_idx_start - n_skipped_values
484 ..slice_value_idx_stop - n_skipped_values,
485 );
486
487 unsafe {
489 VarBinViewArray::new_unchecked(
490 valid_views,
491 Arc::from([decompressed]),
492 self.dtype.clone(),
493 slice_validity,
494 )
495 }
496 .into_array()
497 }
498 AllOr::None => {
499 ConstantArray::new(Scalar::null(self.dtype.clone()), slice_n_rows)
500 .into_array()
501 }
502 AllOr::Some(valid_indices) => {
503 let valid_views = reconstruct_views(decompressed.clone()).slice(
507 slice_value_idx_start - n_skipped_values
508 ..slice_value_idx_stop - n_skipped_values,
509 );
510
511 let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
512 for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
513 views[*index] = view
514 }
515
516 unsafe {
518 VarBinViewArray::new_unchecked(
519 views.freeze(),
520 Arc::from([decompressed]),
521 self.dtype.clone(),
522 slice_validity,
523 )
524 }
525 .into_array()
526 }
527 }
528 }
529 _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype),
530 }
531 }
532
533 pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
534 ZstdArray {
535 slice_start: self.slice_start + start,
536 slice_stop: self.slice_start + stop,
537 stats_set: Default::default(),
538 ..self.clone()
539 }
540 }
541
542 pub(crate) fn dtype(&self) -> &DType {
543 &self.dtype
544 }
545
546 pub(crate) fn slice_start(&self) -> usize {
547 self.slice_start
548 }
549
550 pub(crate) fn slice_stop(&self) -> usize {
551 self.slice_stop
552 }
553
554 pub(crate) fn unsliced_n_rows(&self) -> usize {
555 self.unsliced_n_rows
556 }
557}
558
impl ValiditySliceHelper for ZstdArray {
    /// Expose the unsliced validity plus the current slice bounds; consumed by
    /// `ValidityVTableFromValiditySliceHelper` (see the `VTable` impl).
    fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
        (&self.unsliced_validity, self.slice_start, self.slice_stop)
    }
}
564
565impl ArrayVTable<ZstdVTable> for ZstdVTable {
566 fn len(array: &ZstdArray) -> usize {
567 array.slice_stop - array.slice_start
568 }
569
570 fn dtype(array: &ZstdArray) -> &DType {
571 &array.dtype
572 }
573
574 fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
575 array.stats_set.to_ref(array.as_ref())
576 }
577
578 fn array_hash<H: std::hash::Hasher>(array: &ZstdArray, state: &mut H, precision: Precision) {
579 match &array.dictionary {
580 Some(dict) => {
581 true.hash(state);
582 dict.array_hash(state, precision);
583 }
584 None => {
585 false.hash(state);
586 }
587 }
588 for frame in &array.frames {
589 frame.array_hash(state, precision);
590 }
591 array.dtype.hash(state);
592 array.unsliced_validity.array_hash(state, precision);
593 array.unsliced_n_rows.hash(state);
594 array.slice_start.hash(state);
595 array.slice_stop.hash(state);
596 }
597
598 fn array_eq(array: &ZstdArray, other: &ZstdArray, precision: Precision) -> bool {
599 if !match (&array.dictionary, &other.dictionary) {
600 (Some(d1), Some(d2)) => d1.array_eq(d2, precision),
601 (None, None) => true,
602 _ => false,
603 } {
604 return false;
605 }
606 if array.frames.len() != other.frames.len() {
607 return false;
608 }
609 for (a, b) in array.frames.iter().zip(&other.frames) {
610 if !a.array_eq(b, precision) {
611 return false;
612 }
613 }
614 array.dtype == other.dtype
615 && array
616 .unsliced_validity
617 .array_eq(&other.unsliced_validity, precision)
618 && array.unsliced_n_rows == other.unsliced_n_rows
619 && array.slice_start == other.slice_start
620 && array.slice_stop == other.slice_stop
621 }
622}
623
impl CanonicalVTable<ZstdVTable> for ZstdVTable {
    /// Decompress the (sliced) array and convert the result to canonical form.
    fn canonicalize(array: &ZstdArray) -> Canonical {
        array.decompress().to_canonical()
    }
}
629
impl OperationsVTable<ZstdVTable> for ZstdVTable {
    /// Slicing is lazy: only the slice bounds move; frames stay compressed.
    fn slice(array: &ZstdArray, range: Range<usize>) -> ArrayRef {
        array._slice(range.start, range.end).into_array()
    }

    /// Reads one scalar by decompressing a single-element slice.
    fn scalar_at(array: &ZstdArray, index: usize) -> Scalar {
        array._slice(index, index + 1).decompress().scalar_at(0)
    }
}