1use std::fmt::Debug;
5use std::ops::Range;
6use std::sync::Arc;
7
8use itertools::Itertools as _;
9use vortex_array::accessor::ArrayAccessor;
10use vortex_array::arrays::{BinaryView, ConstantArray, PrimitiveArray, VarBinViewArray};
11use vortex_array::compute::filter;
12use vortex_array::stats::{ArrayStats, StatsSetRef};
13use vortex_array::validity::Validity;
14use vortex_array::vtable::{
15 ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
16 ValiditySliceHelper, ValidityVTableFromValiditySliceHelper,
17};
18use vortex_array::{ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
19use vortex_buffer::{Alignment, Buffer, BufferMut, ByteBuffer, ByteBufferMut};
20use vortex_dtype::DType;
21use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_err, vortex_panic};
22use vortex_mask::AllOr;
23use vortex_scalar::Scalar;
24
25use crate::serde::{ZstdFrameMetadata, ZstdMetadata};
26
/// Minimum number of frames (dictionary-training samples) required before a shared zstd
/// dictionary is trained; with fewer frames the values are compressed without a dictionary.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
/// Type of the little-endian length prefix written before every binary/utf8 value in the
/// uncompressed byte stream that is fed to zstd.
type ViewLen = u32;

// NOTE(review): presumably generates the `ZstdVTable` type used by the impls below (it is not
// defined anywhere else in view) — confirm against `vortex_array::vtable!`.
vtable!(Zstd);
50
51impl VTable for ZstdVTable {
52 type Array = ZstdArray;
53 type Encoding = ZstdEncoding;
54
55 type ArrayVTable = Self;
56 type CanonicalVTable = Self;
57 type OperationsVTable = Self;
58 type ValidityVTable = ValidityVTableFromValiditySliceHelper;
59 type VisitorVTable = Self;
60 type ComputeVTable = NotSupported;
61 type EncodeVTable = Self;
62 type SerdeVTable = Self;
63 type PipelineVTable = NotSupported;
64
65 fn id(_encoding: &Self::Encoding) -> EncodingId {
66 EncodingId::new_ref("vortex.zstd")
67 }
68
69 fn encoding(_array: &Self::Array) -> EncodingRef {
70 EncodingRef::new_ref(ZstdEncoding.as_ref())
71 }
72}
73
/// Stateless encoding marker for zstd-compressed arrays (id `"vortex.zstd"`).
#[derive(Clone, Debug)]
pub struct ZstdEncoding;
76
/// An array whose valid (non-null) values are stored as a sequence of zstd frames.
///
/// Slicing is lazy: the compressed data and validity always describe the full, unsliced
/// array, and `slice_start..slice_stop` selects the visible rows.
#[derive(Clone, Debug)]
pub struct ZstdArray {
    /// Optional zstd dictionary shared by every frame.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// Compressed frames, in value order.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Dictionary size and per-frame metadata (parallel to `frames`).
    pub(crate) metadata: ZstdMetadata,
    dtype: DType,
    /// Validity of all `unsliced_n_rows` rows.
    pub(crate) unsliced_validity: Validity,
    /// Row count of the full array before any slicing.
    unsliced_n_rows: usize,
    stats_set: ArrayStats,
    /// First visible row (inclusive), relative to the unsliced array.
    slice_start: usize,
    /// End of the visible rows (exclusive), relative to the unsliced array.
    slice_stop: usize,
}
89
/// Output of `ZstdArray::compress_values`: the optional trained dictionary plus the compressed
/// frames and their metadata (`frames` and `frame_metas` are parallel vectors).
struct Frames {
    dictionary: Option<ByteBuffer>,
    frames: Vec<ByteBuffer>,
    frame_metas: Vec<ZstdFrameMetadata>,
}
95
/// Upper bound for a trained zstd dictionary: about 1% of the uncompressed input,
/// but never below 256 bytes and never above 100 KiB.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const FLOOR: usize = 256;
    const CEILING: usize = 100 * 1024;
    let one_percent = uncompressed_size / 100;
    one_percent.max(FLOOR).min(CEILING)
}
103
104fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
105 let mask = parray.validity_mask();
106 Ok(filter(&parray.to_array(), &mask)?.to_primitive())
107}
108
109fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
110 let mask = vbv.validity_mask();
111 let buffer_and_value_byte_indices = match mask.boolean_buffer() {
112 AllOr::None => (Buffer::empty(), Vec::new()),
113 _ => {
114 let mut buffer = BufferMut::with_capacity(
115 usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
116 + mask.true_count() * size_of::<ViewLen>(),
117 );
118 let mut value_byte_indices = Vec::new();
119 vbv.with_iterator(|iterator| {
120 for value in iterator.flatten() {
122 value_byte_indices.push(buffer.len());
123 buffer
125 .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
126 buffer.extend_from_slice(value);
127 }
128 Ok::<_, VortexError>(())
129 })??;
130 (buffer.freeze(), value_byte_indices)
131 }
132 };
133 Ok(buffer_and_value_byte_indices)
134}
135
136fn reconstruct_views(buffer: ByteBuffer) -> Buffer<BinaryView> {
137 let mut res = BufferMut::<BinaryView>::empty();
138 let mut offset = 0;
139 while offset < buffer.len() {
140 let str_len = ViewLen::from_le_bytes(
141 buffer
142 .get(offset..offset + size_of::<ViewLen>())
143 .vortex_expect("corrupted zstd length")
144 .try_into()
145 .vortex_expect("must fit ViewLen size"),
146 ) as usize;
147 offset += size_of::<ViewLen>();
148 let value = &buffer[offset..offset + str_len];
149 res.push(BinaryView::make_view(
150 value,
151 0,
152 u32::try_from(offset).vortex_expect("offset must fit in u32"),
153 ));
154 offset += str_len;
155 }
156 res.freeze()
157}
158
159impl ZstdArray {
160 pub fn new(
161 dictionary: Option<ByteBuffer>,
162 frames: Vec<ByteBuffer>,
163 dtype: DType,
164 metadata: ZstdMetadata,
165 n_rows: usize,
166 validity: Validity,
167 ) -> Self {
168 Self {
169 dictionary,
170 frames,
171 metadata,
172 dtype,
173 unsliced_validity: validity,
174 unsliced_n_rows: n_rows,
175 stats_set: Default::default(),
176 slice_start: 0,
177 slice_stop: n_rows,
178 }
179 }
180
181 fn compress_values(
182 value_bytes: &ByteBuffer,
183 frame_byte_starts: &[usize],
184 level: i32,
185 values_per_frame: usize,
186 n_values: usize,
187 ) -> VortexResult<Frames> {
188 let n_frames = frame_byte_starts.len();
189
190 let mut sample_sizes = Vec::with_capacity(n_frames);
192 for i in 0..n_frames {
193 let frame_byte_end = frame_byte_starts
194 .get(i + 1)
195 .copied()
196 .unwrap_or(value_bytes.len());
197 sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
198 }
199 debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
200
201 let (dictionary, mut compressor) = if sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY {
202 (None, zstd::bulk::Compressor::new(level)?)
204 } else {
205 let max_dict_size = choose_max_dict_size(value_bytes.len());
207 let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
208 .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
209
210 let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
211 (Some(ByteBuffer::from(dict)), compressor)
212 };
213
214 let mut frame_metas = vec![];
215 let mut frames = vec![];
216 for i in 0..n_frames {
217 let frame_byte_end = frame_byte_starts
218 .get(i + 1)
219 .copied()
220 .unwrap_or(value_bytes.len());
221 let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
222 let compressed = compressor
223 .compress(uncompressed)
224 .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
225 frame_metas.push(ZstdFrameMetadata {
226 uncompressed_size: uncompressed.len() as u64,
227 n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
228 });
229 frames.push(ByteBuffer::from(compressed));
230 }
231
232 Ok(Frames {
233 dictionary,
234 frames,
235 frame_metas,
236 })
237 }
238
239 pub fn from_primitive(
240 parray: &PrimitiveArray,
241 level: i32,
242 values_per_frame: usize,
243 ) -> VortexResult<Self> {
244 let dtype = parray.dtype().clone();
245 let byte_width = parray.ptype().byte_width();
246
247 let values = collect_valid_primitive(parray)?;
249 let n_values = values.len();
250 let values_per_frame = if values_per_frame > 0 {
251 values_per_frame
252 } else {
253 n_values
254 };
255
256 let value_bytes = values.byte_buffer();
257 let frame_byte_starts = (0..n_values * byte_width)
258 .step_by(values_per_frame * byte_width)
259 .collect::<Vec<_>>();
260 let Frames {
261 dictionary,
262 frames,
263 frame_metas,
264 } = Self::compress_values(
265 value_bytes,
266 &frame_byte_starts,
267 level,
268 values_per_frame,
269 n_values,
270 )?;
271
272 let metadata = ZstdMetadata {
273 dictionary_size: dictionary
274 .as_ref()
275 .map_or(0, |dict| dict.len())
276 .try_into()?,
277 frames: frame_metas,
278 };
279
280 Ok(ZstdArray::new(
281 dictionary,
282 frames,
283 dtype,
284 metadata,
285 parray.len(),
286 parray.validity().clone(),
287 ))
288 }
289
290 pub fn from_var_bin_view(
291 vbv: &VarBinViewArray,
292 level: i32,
293 values_per_frame: usize,
294 ) -> VortexResult<Self> {
295 let dtype = vbv.dtype().clone();
301
302 let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
304 let n_values = value_byte_indices.len();
305 let values_per_frame = if values_per_frame > 0 {
306 values_per_frame
307 } else {
308 n_values
309 };
310
311 let frame_byte_starts = (0..n_values)
312 .step_by(values_per_frame)
313 .map(|i| value_byte_indices[i])
314 .collect::<Vec<_>>();
315 let Frames {
316 dictionary,
317 frames,
318 frame_metas,
319 } = Self::compress_values(
320 &value_bytes,
321 &frame_byte_starts,
322 level,
323 values_per_frame,
324 n_values,
325 )?;
326
327 let metadata = ZstdMetadata {
328 dictionary_size: dictionary
329 .as_ref()
330 .map_or(0, |dict| dict.len())
331 .try_into()?,
332 frames: frame_metas,
333 };
334 Ok(ZstdArray::new(
335 dictionary,
336 frames,
337 dtype,
338 metadata,
339 vbv.len(),
340 vbv.validity().clone(),
341 ))
342 }
343
344 pub fn from_canonical(
345 canonical: &Canonical,
346 level: i32,
347 values_per_frame: usize,
348 ) -> VortexResult<Option<Self>> {
349 match canonical {
350 Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
351 parray,
352 level,
353 values_per_frame,
354 )?)),
355 Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
356 vbv,
357 level,
358 values_per_frame,
359 )?)),
360 _ => Ok(None),
361 }
362 }
363
364 pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
365 Self::from_canonical(&array.to_canonical(), level, values_per_frame)?
366 .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
367 }
368
369 fn byte_width(&self) -> usize {
370 if self.dtype.is_primitive() {
371 self.dtype.as_ptype().byte_width()
372 } else {
373 1
374 }
375 }
376
377 pub fn decompress(&self) -> ArrayRef {
378 let byte_width = self.byte_width();
381 let slice_n_rows = self.slice_stop - self.slice_start;
382 let slice_value_indices = self
383 .unsliced_validity
384 .to_mask(self.unsliced_n_rows)
385 .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
386
387 let slice_value_idx_start = slice_value_indices[0];
388 let slice_value_idx_stop = slice_value_indices[1];
389
390 let mut frames_to_decompress = vec![];
391 let mut value_idx_start = 0;
392 let mut uncompressed_size_to_decompress = 0;
393 let mut n_skipped_values = 0;
394 for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
395 if value_idx_start >= slice_value_idx_stop {
396 break;
397 }
398
399 let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
400 .vortex_expect("Uncompressed size must fit in usize");
401 let frame_n_values = if frame_meta.n_values == 0 {
402 frame_uncompressed_size / byte_width
404 } else {
405 usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
406 };
407
408 let value_idx_stop = value_idx_start + frame_n_values;
409 if value_idx_stop > slice_value_idx_start {
410 frames_to_decompress.push(frame);
412 uncompressed_size_to_decompress += frame_uncompressed_size;
413 } else {
414 n_skipped_values += frame_n_values;
415 }
416 value_idx_start = value_idx_stop;
417 }
418
419 let mut decompressor = if let Some(dictionary) = &self.dictionary {
421 zstd::bulk::Decompressor::with_dictionary(dictionary)
422 } else {
423 zstd::bulk::Decompressor::new()
424 }
425 .vortex_expect("Decompressor encountered io error");
426 let mut decompressed = ByteBufferMut::with_capacity_aligned(
427 uncompressed_size_to_decompress,
428 Alignment::new(byte_width),
429 );
430 unsafe {
431 decompressed.set_len(uncompressed_size_to_decompress);
434 }
435 let mut uncompressed_start = 0;
436 for frame in frames_to_decompress {
437 let uncompressed_written = decompressor
438 .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])
439 .vortex_expect("error while decompressing zstd array");
440 uncompressed_start += uncompressed_written;
441 }
442 if uncompressed_start != uncompressed_size_to_decompress {
443 vortex_panic!(
444 "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
445 uncompressed_size_to_decompress,
446 uncompressed_start
447 );
448 }
449
450 let decompressed = decompressed.freeze();
451 let slice_validity = self
453 .unsliced_validity
454 .slice(self.slice_start..self.slice_stop);
455
456 match &self.dtype {
457 DType::Primitive(..) => {
458 let slice_values_buffer = decompressed.slice(
459 (slice_value_idx_start - n_skipped_values) * byte_width
460 ..(slice_value_idx_stop - n_skipped_values) * byte_width,
461 );
462 let primitive = PrimitiveArray::from_values_byte_buffer(
463 slice_values_buffer,
464 self.dtype.as_ptype(),
465 slice_validity,
466 slice_n_rows,
467 );
468
469 primitive.into_array()
470 }
471 DType::Binary(_) | DType::Utf8(_) => {
472 match slice_validity.to_mask(slice_n_rows).indices() {
473 AllOr::All => {
474 let valid_views = reconstruct_views(decompressed.clone()).slice(
478 slice_value_idx_start - n_skipped_values
479 ..slice_value_idx_stop - n_skipped_values,
480 );
481
482 unsafe {
484 VarBinViewArray::new_unchecked(
485 valid_views,
486 Arc::from([decompressed]),
487 self.dtype.clone(),
488 slice_validity,
489 )
490 }
491 .into_array()
492 }
493 AllOr::None => {
494 ConstantArray::new(Scalar::null(self.dtype.clone()), slice_n_rows)
495 .into_array()
496 }
497 AllOr::Some(valid_indices) => {
498 let valid_views = reconstruct_views(decompressed.clone()).slice(
502 slice_value_idx_start - n_skipped_values
503 ..slice_value_idx_stop - n_skipped_values,
504 );
505
506 let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
507 for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
508 views[*index] = view
509 }
510
511 unsafe {
513 VarBinViewArray::new_unchecked(
514 views.freeze(),
515 Arc::from([decompressed]),
516 self.dtype.clone(),
517 slice_validity,
518 )
519 }
520 .into_array()
521 }
522 }
523 }
524 _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype),
525 }
526 }
527
528 pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
529 ZstdArray {
530 slice_start: self.slice_start + start,
531 slice_stop: self.slice_start + stop,
532 stats_set: Default::default(),
533 ..self.clone()
534 }
535 }
536
537 pub(crate) fn dtype(&self) -> &DType {
538 &self.dtype
539 }
540
541 pub(crate) fn slice_start(&self) -> usize {
542 self.slice_start
543 }
544
545 pub(crate) fn slice_stop(&self) -> usize {
546 self.slice_stop
547 }
548
549 pub(crate) fn unsliced_n_rows(&self) -> usize {
550 self.unsliced_n_rows
551 }
552}
553
554impl ValiditySliceHelper for ZstdArray {
555 fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
556 (&self.unsliced_validity, self.slice_start, self.slice_stop)
557 }
558}
559
560impl ArrayVTable<ZstdVTable> for ZstdVTable {
561 fn len(array: &ZstdArray) -> usize {
562 array.slice_stop - array.slice_start
563 }
564
565 fn dtype(array: &ZstdArray) -> &DType {
566 &array.dtype
567 }
568
569 fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
570 array.stats_set.to_ref(array.as_ref())
571 }
572}
573
574impl CanonicalVTable<ZstdVTable> for ZstdVTable {
575 fn canonicalize(array: &ZstdArray) -> Canonical {
576 array.decompress().to_canonical()
577 }
578}
579
580impl OperationsVTable<ZstdVTable> for ZstdVTable {
581 fn slice(array: &ZstdArray, range: Range<usize>) -> ArrayRef {
582 array._slice(range.start, range.end).into_array()
583 }
584
585 fn scalar_at(array: &ZstdArray, index: usize) -> Scalar {
586 array._slice(index, index + 1).decompress().scalar_at(0)
587 }
588}