1use std::fmt::Debug;
5use std::ops::Range;
6use std::sync::Arc;
7
8use itertools::Itertools as _;
9use vortex_array::accessor::ArrayAccessor;
10use vortex_array::arrays::binary_view::BinaryView;
11use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinViewArray};
12use vortex_array::compute::filter;
13use vortex_array::stats::{ArrayStats, StatsSetRef};
14use vortex_array::validity::Validity;
15use vortex_array::vtable::{
16 ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
17 ValiditySliceHelper, ValidityVTableFromValiditySliceHelper,
18};
19use vortex_array::{ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
20use vortex_buffer::{Alignment, Buffer, BufferMut, ByteBuffer, ByteBufferMut};
21use vortex_dtype::DType;
22use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_err, vortex_panic};
23use vortex_mask::AllOr;
24use vortex_scalar::Scalar;
25
26use crate::serde::{ZstdFrameMetadata, ZstdMetadata};
27
/// Fewest frames (each frame is used as one training sample) for which a zstd dictionary is
/// trained; with fewer frames, values are compressed without a dictionary.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
/// Integer type of the little-endian length prefix written in front of every binary value
/// before compression.
type ViewLen = u32;
31
32vtable!(Zstd);
51
52impl VTable for ZstdVTable {
53 type Array = ZstdArray;
54 type Encoding = ZstdEncoding;
55
56 type ArrayVTable = Self;
57 type CanonicalVTable = Self;
58 type OperationsVTable = Self;
59 type ValidityVTable = ValidityVTableFromValiditySliceHelper;
60 type VisitorVTable = Self;
61 type ComputeVTable = NotSupported;
62 type EncodeVTable = Self;
63 type SerdeVTable = Self;
64 type PipelineVTable = NotSupported;
65
66 fn id(_encoding: &Self::Encoding) -> EncodingId {
67 EncodingId::new_ref("vortex.zstd")
68 }
69
70 fn encoding(_array: &Self::Array) -> EncodingRef {
71 EncodingRef::new_ref(ZstdEncoding.as_ref())
72 }
73}
74
/// Marker type for the zstd encoding (its id, `"vortex.zstd"`, is reported by
/// `ZstdVTable::id`).
#[derive(Debug, Clone)]
pub struct ZstdEncoding;
77
78#[derive(Clone, Debug)]
79pub struct ZstdArray {
80 pub(crate) dictionary: Option<ByteBuffer>,
81 pub(crate) frames: Vec<ByteBuffer>,
82 pub(crate) metadata: ZstdMetadata,
83 dtype: DType,
84 pub(crate) unsliced_validity: Validity,
85 unsliced_n_rows: usize,
86 stats_set: ArrayStats,
87 slice_start: usize,
88 slice_stop: usize,
89}
90
91struct Frames {
92 dictionary: Option<ByteBuffer>,
93 frames: Vec<ByteBuffer>,
94 frame_metas: Vec<ZstdFrameMetadata>,
95}
96
/// Picks the maximum zstd dictionary size for `uncompressed_size` bytes of input:
/// about 1% of the input, but never below 256 bytes nor above 100 KiB.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const MIN_DICT_BYTES: usize = 256;
    const MAX_DICT_BYTES: usize = 100 * 1024;

    let one_percent = uncompressed_size / 100;
    one_percent.clamp(MIN_DICT_BYTES, MAX_DICT_BYTES)
}
104
105fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
106 let mask = parray.validity_mask();
107 Ok(filter(&parray.to_array(), &mask)?.to_primitive())
108}
109
110fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
111 let mask = vbv.validity_mask();
112 let buffer_and_value_byte_indices = match mask.boolean_buffer() {
113 AllOr::None => (Buffer::empty(), Vec::new()),
114 _ => {
115 let mut buffer = BufferMut::with_capacity(
116 usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
117 + mask.true_count() * size_of::<ViewLen>(),
118 );
119 let mut value_byte_indices = Vec::new();
120 vbv.with_iterator(|iterator| {
121 for value in iterator.flatten() {
123 value_byte_indices.push(buffer.len());
124 buffer
126 .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
127 buffer.extend_from_slice(value);
128 }
129 Ok::<_, VortexError>(())
130 })??;
131 (buffer.freeze(), value_byte_indices)
132 }
133 };
134 Ok(buffer_and_value_byte_indices)
135}
136
137fn reconstruct_views(buffer: ByteBuffer) -> Buffer<BinaryView> {
138 let mut res = BufferMut::<BinaryView>::empty();
139 let mut offset = 0;
140 while offset < buffer.len() {
141 let str_len = ViewLen::from_le_bytes(
142 buffer
143 .get(offset..offset + size_of::<ViewLen>())
144 .vortex_expect("corrupted zstd length")
145 .try_into()
146 .vortex_expect("must fit ViewLen size"),
147 ) as usize;
148 offset += size_of::<ViewLen>();
149 let value = &buffer[offset..offset + str_len];
150 res.push(BinaryView::make_view(
151 value,
152 0,
153 u32::try_from(offset).vortex_expect("offset must fit in u32"),
154 ));
155 offset += str_len;
156 }
157 res.freeze()
158}
159
160impl ZstdArray {
161 pub fn new(
162 dictionary: Option<ByteBuffer>,
163 frames: Vec<ByteBuffer>,
164 dtype: DType,
165 metadata: ZstdMetadata,
166 n_rows: usize,
167 validity: Validity,
168 ) -> Self {
169 Self {
170 dictionary,
171 frames,
172 metadata,
173 dtype,
174 unsliced_validity: validity,
175 unsliced_n_rows: n_rows,
176 stats_set: Default::default(),
177 slice_start: 0,
178 slice_stop: n_rows,
179 }
180 }
181
182 fn compress_values(
183 value_bytes: &ByteBuffer,
184 frame_byte_starts: &[usize],
185 level: i32,
186 values_per_frame: usize,
187 n_values: usize,
188 ) -> VortexResult<Frames> {
189 let n_frames = frame_byte_starts.len();
190
191 let mut sample_sizes = Vec::with_capacity(n_frames);
193 for i in 0..n_frames {
194 let frame_byte_end = frame_byte_starts
195 .get(i + 1)
196 .copied()
197 .unwrap_or(value_bytes.len());
198 sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
199 }
200 debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
201
202 let (dictionary, mut compressor) = if sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY {
203 (None, zstd::bulk::Compressor::new(level)?)
205 } else {
206 let max_dict_size = choose_max_dict_size(value_bytes.len());
208 let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
209 .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
210
211 let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
212 (Some(ByteBuffer::from(dict)), compressor)
213 };
214
215 let mut frame_metas = vec![];
216 let mut frames = vec![];
217 for i in 0..n_frames {
218 let frame_byte_end = frame_byte_starts
219 .get(i + 1)
220 .copied()
221 .unwrap_or(value_bytes.len());
222 let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
223 let compressed = compressor
224 .compress(uncompressed)
225 .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
226 frame_metas.push(ZstdFrameMetadata {
227 uncompressed_size: uncompressed.len() as u64,
228 n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
229 });
230 frames.push(ByteBuffer::from(compressed));
231 }
232
233 Ok(Frames {
234 dictionary,
235 frames,
236 frame_metas,
237 })
238 }
239
240 pub fn from_primitive(
241 parray: &PrimitiveArray,
242 level: i32,
243 values_per_frame: usize,
244 ) -> VortexResult<Self> {
245 let dtype = parray.dtype().clone();
246 let byte_width = parray.ptype().byte_width();
247
248 let values = collect_valid_primitive(parray)?;
250 let n_values = values.len();
251 let values_per_frame = if values_per_frame > 0 {
252 values_per_frame
253 } else {
254 n_values
255 };
256
257 let value_bytes = values.byte_buffer();
258 let frame_byte_starts = (0..n_values * byte_width)
259 .step_by(values_per_frame * byte_width)
260 .collect::<Vec<_>>();
261 let Frames {
262 dictionary,
263 frames,
264 frame_metas,
265 } = Self::compress_values(
266 value_bytes,
267 &frame_byte_starts,
268 level,
269 values_per_frame,
270 n_values,
271 )?;
272
273 let metadata = ZstdMetadata {
274 dictionary_size: dictionary
275 .as_ref()
276 .map_or(0, |dict| dict.len())
277 .try_into()?,
278 frames: frame_metas,
279 };
280
281 Ok(ZstdArray::new(
282 dictionary,
283 frames,
284 dtype,
285 metadata,
286 parray.len(),
287 parray.validity().clone(),
288 ))
289 }
290
291 pub fn from_var_bin_view(
292 vbv: &VarBinViewArray,
293 level: i32,
294 values_per_frame: usize,
295 ) -> VortexResult<Self> {
296 let dtype = vbv.dtype().clone();
302
303 let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
305 let n_values = value_byte_indices.len();
306 let values_per_frame = if values_per_frame > 0 {
307 values_per_frame
308 } else {
309 n_values
310 };
311
312 let frame_byte_starts = (0..n_values)
313 .step_by(values_per_frame)
314 .map(|i| value_byte_indices[i])
315 .collect::<Vec<_>>();
316 let Frames {
317 dictionary,
318 frames,
319 frame_metas,
320 } = Self::compress_values(
321 &value_bytes,
322 &frame_byte_starts,
323 level,
324 values_per_frame,
325 n_values,
326 )?;
327
328 let metadata = ZstdMetadata {
329 dictionary_size: dictionary
330 .as_ref()
331 .map_or(0, |dict| dict.len())
332 .try_into()?,
333 frames: frame_metas,
334 };
335 Ok(ZstdArray::new(
336 dictionary,
337 frames,
338 dtype,
339 metadata,
340 vbv.len(),
341 vbv.validity().clone(),
342 ))
343 }
344
345 pub fn from_canonical(
346 canonical: &Canonical,
347 level: i32,
348 values_per_frame: usize,
349 ) -> VortexResult<Option<Self>> {
350 match canonical {
351 Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
352 parray,
353 level,
354 values_per_frame,
355 )?)),
356 Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
357 vbv,
358 level,
359 values_per_frame,
360 )?)),
361 _ => Ok(None),
362 }
363 }
364
365 pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
366 Self::from_canonical(&array.to_canonical(), level, values_per_frame)?
367 .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
368 }
369
370 fn byte_width(&self) -> usize {
371 if self.dtype.is_primitive() {
372 self.dtype.as_ptype().byte_width()
373 } else {
374 1
375 }
376 }
377
378 pub fn decompress(&self) -> ArrayRef {
379 let byte_width = self.byte_width();
382 let slice_n_rows = self.slice_stop - self.slice_start;
383 let slice_value_indices = self
384 .unsliced_validity
385 .to_mask(self.unsliced_n_rows)
386 .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
387
388 let slice_value_idx_start = slice_value_indices[0];
389 let slice_value_idx_stop = slice_value_indices[1];
390
391 let mut frames_to_decompress = vec![];
392 let mut value_idx_start = 0;
393 let mut uncompressed_size_to_decompress = 0;
394 let mut n_skipped_values = 0;
395 for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
396 if value_idx_start >= slice_value_idx_stop {
397 break;
398 }
399
400 let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
401 .vortex_expect("Uncompressed size must fit in usize");
402 let frame_n_values = if frame_meta.n_values == 0 {
403 frame_uncompressed_size / byte_width
405 } else {
406 usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
407 };
408
409 let value_idx_stop = value_idx_start + frame_n_values;
410 if value_idx_stop > slice_value_idx_start {
411 frames_to_decompress.push(frame);
413 uncompressed_size_to_decompress += frame_uncompressed_size;
414 } else {
415 n_skipped_values += frame_n_values;
416 }
417 value_idx_start = value_idx_stop;
418 }
419
420 let mut decompressor = if let Some(dictionary) = &self.dictionary {
422 zstd::bulk::Decompressor::with_dictionary(dictionary)
423 } else {
424 zstd::bulk::Decompressor::new()
425 }
426 .vortex_expect("Decompressor encountered io error");
427 let mut decompressed = ByteBufferMut::with_capacity_aligned(
428 uncompressed_size_to_decompress,
429 Alignment::new(byte_width),
430 );
431 unsafe {
432 decompressed.set_len(uncompressed_size_to_decompress);
435 }
436 let mut uncompressed_start = 0;
437 for frame in frames_to_decompress {
438 let uncompressed_written = decompressor
439 .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])
440 .vortex_expect("error while decompressing zstd array");
441 uncompressed_start += uncompressed_written;
442 }
443 if uncompressed_start != uncompressed_size_to_decompress {
444 vortex_panic!(
445 "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
446 uncompressed_size_to_decompress,
447 uncompressed_start
448 );
449 }
450
451 let decompressed = decompressed.freeze();
452 let slice_validity = self
454 .unsliced_validity
455 .slice(self.slice_start..self.slice_stop);
456
457 match &self.dtype {
458 DType::Primitive(..) => {
459 let slice_values_buffer = decompressed.slice(
460 (slice_value_idx_start - n_skipped_values) * byte_width
461 ..(slice_value_idx_stop - n_skipped_values) * byte_width,
462 );
463 let primitive = PrimitiveArray::from_values_byte_buffer(
464 slice_values_buffer,
465 self.dtype.as_ptype(),
466 slice_validity,
467 slice_n_rows,
468 );
469
470 primitive.into_array()
471 }
472 DType::Binary(_) | DType::Utf8(_) => {
473 match slice_validity.to_mask(slice_n_rows).indices() {
474 AllOr::All => {
475 let valid_views = reconstruct_views(decompressed.clone()).slice(
479 slice_value_idx_start - n_skipped_values
480 ..slice_value_idx_stop - n_skipped_values,
481 );
482
483 unsafe {
485 VarBinViewArray::new_unchecked(
486 valid_views,
487 Arc::from([decompressed]),
488 self.dtype.clone(),
489 slice_validity,
490 )
491 }
492 .into_array()
493 }
494 AllOr::None => {
495 ConstantArray::new(Scalar::null(self.dtype.clone()), slice_n_rows)
496 .into_array()
497 }
498 AllOr::Some(valid_indices) => {
499 let valid_views = reconstruct_views(decompressed.clone()).slice(
503 slice_value_idx_start - n_skipped_values
504 ..slice_value_idx_stop - n_skipped_values,
505 );
506
507 let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
508 for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
509 views[*index] = view
510 }
511
512 unsafe {
514 VarBinViewArray::new_unchecked(
515 views.freeze(),
516 Arc::from([decompressed]),
517 self.dtype.clone(),
518 slice_validity,
519 )
520 }
521 .into_array()
522 }
523 }
524 }
525 _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype),
526 }
527 }
528
529 pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
530 ZstdArray {
531 slice_start: self.slice_start + start,
532 slice_stop: self.slice_start + stop,
533 stats_set: Default::default(),
534 ..self.clone()
535 }
536 }
537
538 pub(crate) fn dtype(&self) -> &DType {
539 &self.dtype
540 }
541
542 pub(crate) fn slice_start(&self) -> usize {
543 self.slice_start
544 }
545
546 pub(crate) fn slice_stop(&self) -> usize {
547 self.slice_stop
548 }
549
550 pub(crate) fn unsliced_n_rows(&self) -> usize {
551 self.unsliced_n_rows
552 }
553}
554
555impl ValiditySliceHelper for ZstdArray {
556 fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
557 (&self.unsliced_validity, self.slice_start, self.slice_stop)
558 }
559}
560
561impl ArrayVTable<ZstdVTable> for ZstdVTable {
562 fn len(array: &ZstdArray) -> usize {
563 array.slice_stop - array.slice_start
564 }
565
566 fn dtype(array: &ZstdArray) -> &DType {
567 &array.dtype
568 }
569
570 fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
571 array.stats_set.to_ref(array.as_ref())
572 }
573}
574
575impl CanonicalVTable<ZstdVTable> for ZstdVTable {
576 fn canonicalize(array: &ZstdArray) -> Canonical {
577 array.decompress().to_canonical()
578 }
579}
580
581impl OperationsVTable<ZstdVTable> for ZstdVTable {
582 fn slice(array: &ZstdArray, range: Range<usize>) -> ArrayRef {
583 array._slice(range.start, range.end).into_array()
584 }
585
586 fn scalar_at(array: &ZstdArray, index: usize) -> Scalar {
587 array._slice(index, index + 1).decompress().scalar_at(0)
588 }
589}