1use std::fmt::Debug;
5use std::sync::Arc;
6
7use vortex_array::accessor::ArrayAccessor;
8use vortex_array::arrays::{BinaryView, PrimitiveArray, VarBinViewArray};
9use vortex_array::compute::filter;
10use vortex_array::stats::{ArrayStats, StatsSetRef};
11use vortex_array::validity::Validity;
12use vortex_array::vtable::{
13 ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
14 ValiditySliceHelper, ValidityVTableFromValiditySliceHelper,
15};
16use vortex_array::{ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
17use vortex_buffer::{Alignment, Buffer, BufferMut, ByteBuffer, ByteBufferMut};
18use vortex_dtype::DType;
19use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
20use vortex_mask::AllOr;
21use vortex_scalar::Scalar;
22
23use crate::serde::{ZstdFrameMetadata, ZstdMetadata};
24
/// Minimum number of frames (dictionary training samples) required before a
/// zstd dictionary is trained; with fewer frames the values are compressed
/// without a dictionary.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
/// Integer type of the little-endian length prefix written in front of every
/// variable-length value in the uncompressed byte stream.
type ViewLen = u32;

// Invokes the `vtable!` macro imported from `vortex_array` for the `Zstd` encoding.
// NOTE(review): presumably this declares `ZstdVTable` and the related glue types —
// confirm against the macro definition.
vtable!(Zstd);
48
49impl VTable for ZstdVTable {
50 type Array = ZstdArray;
51 type Encoding = ZstdEncoding;
52
53 type ArrayVTable = Self;
54 type CanonicalVTable = Self;
55 type OperationsVTable = Self;
56 type ValidityVTable = ValidityVTableFromValiditySliceHelper;
57 type VisitorVTable = Self;
58 type ComputeVTable = NotSupported;
59 type EncodeVTable = Self;
60 type SerdeVTable = Self;
61 type PipelineVTable = NotSupported;
62
63 fn id(_encoding: &Self::Encoding) -> EncodingId {
64 EncodingId::new_ref("vortex.zstd")
65 }
66
67 fn encoding(_array: &Self::Array) -> EncodingRef {
68 EncodingRef::new_ref(ZstdEncoding.as_ref())
69 }
70}
71
/// Stateless marker type identifying the Zstd encoding.
#[derive(Debug, Clone)]
pub struct ZstdEncoding;
74
75#[derive(Clone, Debug)]
76pub struct ZstdArray {
77 pub(crate) dictionary: Option<ByteBuffer>,
78 pub(crate) frames: Vec<ByteBuffer>,
79 pub(crate) metadata: ZstdMetadata,
80 dtype: DType,
81 pub(crate) unsliced_validity: Validity,
82 unsliced_n_rows: usize,
83 stats_set: ArrayStats,
84 slice_start: usize,
85 slice_stop: usize,
86}
87
88struct Frames {
89 dictionary: Option<ByteBuffer>,
90 frames: Vec<ByteBuffer>,
91 frame_metas: Vec<ZstdFrameMetadata>,
92}
93
/// Picks the maximum dictionary size for `uncompressed_size` bytes of input:
/// one percent of the input, bounded to the range 256 bytes ..= 100 KiB.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const SMALLEST_DICT: usize = 256;
    const LARGEST_DICT: usize = 100 * 1024;

    let one_percent = uncompressed_size / 100;
    one_percent.max(SMALLEST_DICT).min(LARGEST_DICT)
}
101
102fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
103 let mask = parray.validity_mask()?;
104 filter(&parray.to_array(), &mask)?.to_primitive()
105}
106
107fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
108 let mask = vbv.validity_mask()?;
109 let buffer_and_value_byte_indices = match mask.boolean_buffer() {
110 AllOr::None => (Buffer::empty(), Vec::new()),
111 _ => {
112 let mut buffer = BufferMut::with_capacity(
113 usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
114 + mask.true_count() * size_of::<ViewLen>(),
115 );
116 let mut value_byte_indices = Vec::new();
117 vbv.with_iterator(|iterator| {
118 for value in iterator.flatten() {
120 value_byte_indices.push(buffer.len());
121 buffer
123 .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
124 buffer.extend_from_slice(value);
125 }
126 Ok::<_, VortexError>(())
127 })??;
128 (buffer.freeze(), value_byte_indices)
129 }
130 };
131 Ok(buffer_and_value_byte_indices)
132}
133
134fn reconstruct_views(buffer: ByteBuffer) -> VortexResult<Buffer<BinaryView>> {
135 let mut res = BufferMut::<BinaryView>::empty();
136 let mut offset = 0;
137 while offset < buffer.len() {
138 let str_len = ViewLen::from_le_bytes(
139 buffer
140 .get(offset..offset + size_of::<ViewLen>())
141 .ok_or_else(|| vortex_err!("Zstd buffer for VarBinView was corrupt"))?
142 .try_into()?,
143 ) as usize;
144 offset += size_of::<ViewLen>();
145 let value = &buffer[offset..offset + str_len];
146 res.push(BinaryView::make_view(value, 0, u32::try_from(offset)?));
147 offset += str_len;
148 }
149 Ok(res.freeze())
150}
151
152impl ZstdArray {
153 pub fn new(
154 dictionary: Option<ByteBuffer>,
155 frames: Vec<ByteBuffer>,
156 dtype: DType,
157 metadata: ZstdMetadata,
158 n_rows: usize,
159 validity: Validity,
160 ) -> Self {
161 Self {
162 dictionary,
163 frames,
164 metadata,
165 dtype,
166 unsliced_validity: validity,
167 unsliced_n_rows: n_rows,
168 stats_set: Default::default(),
169 slice_start: 0,
170 slice_stop: n_rows,
171 }
172 }
173
174 fn compress_values(
175 value_bytes: &ByteBuffer,
176 frame_byte_starts: &[usize],
177 level: i32,
178 values_per_frame: usize,
179 n_values: usize,
180 ) -> VortexResult<Frames> {
181 let n_frames = frame_byte_starts.len();
182
183 let mut sample_sizes = Vec::with_capacity(n_frames);
185 for i in 0..n_frames {
186 let frame_byte_end = frame_byte_starts
187 .get(i + 1)
188 .copied()
189 .unwrap_or(value_bytes.len());
190 sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
191 }
192 debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
193
194 let (dictionary, mut compressor) = if sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY {
195 (None, zstd::bulk::Compressor::new(level)?)
197 } else {
198 let max_dict_size = choose_max_dict_size(value_bytes.len());
200 let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
201 .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
202
203 let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
204 (Some(ByteBuffer::from(dict)), compressor)
205 };
206
207 let mut frame_metas = vec![];
208 let mut frames = vec![];
209 for i in 0..n_frames {
210 let frame_byte_end = frame_byte_starts
211 .get(i + 1)
212 .copied()
213 .unwrap_or(value_bytes.len());
214 let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
215 let compressed = compressor
216 .compress(uncompressed)
217 .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
218 frame_metas.push(ZstdFrameMetadata {
219 uncompressed_size: uncompressed.len() as u64,
220 n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
221 });
222 frames.push(ByteBuffer::from(compressed));
223 }
224
225 Ok(Frames {
226 dictionary,
227 frames,
228 frame_metas,
229 })
230 }
231
232 pub fn from_primitive(
233 parray: &PrimitiveArray,
234 level: i32,
235 values_per_frame: usize,
236 ) -> VortexResult<Self> {
237 let dtype = parray.dtype().clone();
238 let byte_width = parray.ptype().byte_width();
239
240 let values = collect_valid_primitive(parray)?;
242 let n_values = values.len();
243 let values_per_frame = if values_per_frame > 0 {
244 values_per_frame
245 } else {
246 n_values
247 };
248
249 let value_bytes = values.byte_buffer();
250 let frame_byte_starts = (0..n_values * byte_width)
251 .step_by(values_per_frame * byte_width)
252 .collect::<Vec<_>>();
253 let Frames {
254 dictionary,
255 frames,
256 frame_metas,
257 } = Self::compress_values(
258 value_bytes,
259 &frame_byte_starts,
260 level,
261 values_per_frame,
262 n_values,
263 )?;
264
265 let metadata = ZstdMetadata {
266 dictionary_size: dictionary
267 .as_ref()
268 .map_or(0, |dict| dict.len())
269 .try_into()?,
270 frames: frame_metas,
271 };
272
273 Ok(ZstdArray::new(
274 dictionary,
275 frames,
276 dtype,
277 metadata,
278 parray.len(),
279 parray.validity().clone(),
280 ))
281 }
282
283 pub fn from_var_bin_view(
284 vbv: &VarBinViewArray,
285 level: i32,
286 values_per_frame: usize,
287 ) -> VortexResult<Self> {
288 let dtype = vbv.dtype().clone();
294
295 let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
297 let n_values = value_byte_indices.len();
298 let values_per_frame = if values_per_frame > 0 {
299 values_per_frame
300 } else {
301 n_values
302 };
303
304 let frame_byte_starts = (0..n_values)
305 .step_by(values_per_frame)
306 .map(|i| value_byte_indices[i])
307 .collect::<Vec<_>>();
308 let Frames {
309 dictionary,
310 frames,
311 frame_metas,
312 } = Self::compress_values(
313 &value_bytes,
314 &frame_byte_starts,
315 level,
316 values_per_frame,
317 n_values,
318 )?;
319
320 let metadata = ZstdMetadata {
321 dictionary_size: dictionary
322 .as_ref()
323 .map_or(0, |dict| dict.len())
324 .try_into()?,
325 frames: frame_metas,
326 };
327 Ok(ZstdArray::new(
328 dictionary,
329 frames,
330 dtype,
331 metadata,
332 vbv.len(),
333 vbv.validity().clone(),
334 ))
335 }
336
337 pub fn from_canonical(
338 canonical: &Canonical,
339 level: i32,
340 values_per_frame: usize,
341 ) -> VortexResult<Option<Self>> {
342 match canonical {
343 Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
344 parray,
345 level,
346 values_per_frame,
347 )?)),
348 Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
349 vbv,
350 level,
351 values_per_frame,
352 )?)),
353 _ => Ok(None),
354 }
355 }
356
357 pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
358 Self::from_canonical(&array.to_canonical()?, level, values_per_frame)?
359 .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
360 }
361
362 fn byte_width(&self) -> usize {
363 if self.dtype.is_primitive() {
364 self.dtype.as_ptype().byte_width()
365 } else {
366 1
367 }
368 }
369
    /// Decompresses the current slice of this array into a primitive or
    /// VarBinView array.
    ///
    /// Only the frames that overlap the slice are decompressed; the values of
    /// the slice are then cut out of the decompressed bytes and combined with
    /// the sliced validity.
    ///
    /// # Errors
    ///
    /// Fails if the validity mask cannot be computed, a metadata size does not
    /// fit in `usize`, zstd decompression fails, the number of decompressed
    /// bytes differs from what the metadata announces, or the dtype is not
    /// primitive, binary or UTF-8.
    pub fn decompress(&self) -> VortexResult<ArrayRef> {
        let byte_width = self.byte_width();
        let slice_n_rows = self.slice_stop - self.slice_start;
        // Translate the slice's row bounds into bounds over the stored values.
        // NOTE(review): presumably `valid_counts_for_indices` returns the number of
        // valid rows preceding each given index — confirm against `vortex_mask`.
        let slice_value_indices = self
            .unsliced_validity
            .to_mask(self.unsliced_n_rows)?
            .valid_counts_for_indices(&[self.slice_start, self.slice_stop])?;

        let slice_value_idx_start = slice_value_indices[0];
        let slice_value_idx_stop = slice_value_indices[1];

        // Select the frames overlapping the value range of the slice.
        let mut frames_to_decompress = vec![];
        let mut value_idx_start = 0;
        let mut uncompressed_size_to_decompress = 0;
        let mut n_skipped_values = 0;
        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
            // Every remaining frame starts at or after the end of the slice.
            if value_idx_start >= slice_value_idx_stop {
                break;
            }

            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)?;
            // A recorded count of 0 falls back to deriving the count from the frame size.
            // NOTE(review): presumably for metadata written without per-frame value
            // counts — confirm.
            let frame_n_values = if frame_meta.n_values == 0 {
                frame_uncompressed_size / byte_width
            } else {
                usize::try_from(frame_meta.n_values)?
            };

            let value_idx_stop = value_idx_start + frame_n_values;
            if value_idx_stop > slice_value_idx_start {
                frames_to_decompress.push(frame);
                uncompressed_size_to_decompress += frame_uncompressed_size;
            } else {
                // The frame lies entirely before the slice: skip it, but remember how
                // many values it held so indices can be rebased further down.
                n_skipped_values += frame_n_values;
            }
            value_idx_start = value_idx_stop;
        }

        let mut decompressor = if let Some(dictionary) = &self.dictionary {
            zstd::bulk::Decompressor::with_dictionary(dictionary)
        } else {
            zstd::bulk::Decompressor::new()
        }?;
        let mut decompressed = ByteBufferMut::with_capacity_aligned(
            uncompressed_size_to_decompress,
            Alignment::new(byte_width),
        );
        // SAFETY: the buffer was allocated above with exactly this capacity. The bytes
        // are uninitialized at this point; the loop below writes the decompressed frames
        // into them and the length check after it bails out unless the whole buffer was
        // written.
        // NOTE(review): this relies on the decompressor only writing to (never reading
        // from) the destination slice — confirm.
        unsafe {
            decompressed.set_len(uncompressed_size_to_decompress);
        }
        // Decompress the selected frames back to back into the buffer.
        let mut uncompressed_start = 0;
        for frame in frames_to_decompress {
            let uncompressed_written = decompressor
                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])
                .map_err(|err| VortexError::from(err).with_context("while decompressing"))?;
            uncompressed_start += uncompressed_written;
        }
        if uncompressed_start != uncompressed_size_to_decompress {
            vortex_bail!(
                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
                uncompressed_size_to_decompress,
                uncompressed_start
            );
        }

        let decompressed = decompressed.freeze();
        let slice_validity = self
            .unsliced_validity
            .slice(self.slice_start, self.slice_stop);

        match &self.dtype {
            DType::Primitive(..) => {
                // Value indices are relative to the first decompressed frame, hence
                // the subtraction of the skipped values.
                let slice_values_buffer = decompressed.slice(
                    (slice_value_idx_start - n_skipped_values) * byte_width
                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
                );
                let primitive = PrimitiveArray::from_values_byte_buffer(
                    slice_values_buffer,
                    self.dtype.as_ptype(),
                    slice_validity,
                    slice_n_rows,
                )?;

                Ok(primitive.into_array())
            }
            DType::Binary(_) | DType::Utf8(_) => {
                // Build one view per decompressed value, then keep only the views
                // that belong to the slice.
                let views = reconstruct_views(decompressed.clone())?.slice(
                    slice_value_idx_start - n_skipped_values
                        ..slice_value_idx_stop - n_skipped_values,
                );

                // SAFETY: the views were built by `reconstruct_views` from this very
                // buffer, which is passed as the only data buffer.
                // NOTE(review): the exact invariants required by `new_unchecked` (view
                // bounds, UTF-8 validity, views vs. validity length) are not visible
                // here — confirm they hold.
                let vbv = unsafe {
                    VarBinViewArray::new_unchecked(
                        views,
                        Arc::from([decompressed]),
                        self.dtype.clone(),
                        slice_validity,
                    )
                };
                Ok(vbv.into_array())
            }
            _ => Err(vortex_err!(
                "Unsupported dtype for Zstd array: {:?}",
                self.dtype
            )),
        }
    }
488
489 pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
490 ZstdArray {
491 slice_start: self.slice_start + start,
492 slice_stop: self.slice_start + stop,
493 stats_set: Default::default(),
494 ..self.clone()
495 }
496 }
497
    /// Returns the logical dtype of this array.
    pub(crate) fn dtype(&self) -> &DType {
        &self.dtype
    }
501
    /// Returns the first row (inclusive) of the current slice within the unsliced array.
    pub(crate) fn slice_start(&self) -> usize {
        self.slice_start
    }
505
    /// Returns the row one past the end of the current slice within the unsliced array.
    pub(crate) fn slice_stop(&self) -> usize {
        self.slice_stop
    }
509
    /// Returns the row count of the full array, before any slicing.
    pub(crate) fn unsliced_n_rows(&self) -> usize {
        self.unsliced_n_rows
    }
513}
514
515impl ValiditySliceHelper for ZstdArray {
516 fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
517 (&self.unsliced_validity, self.slice_start, self.slice_stop)
518 }
519}
520
521impl ArrayVTable<ZstdVTable> for ZstdVTable {
522 fn len(array: &ZstdArray) -> usize {
523 array.slice_stop - array.slice_start
524 }
525
526 fn dtype(array: &ZstdArray) -> &DType {
527 &array.dtype
528 }
529
530 fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
531 array.stats_set.to_ref(array.as_ref())
532 }
533}
534
535impl CanonicalVTable<ZstdVTable> for ZstdVTable {
536 fn canonicalize(array: &ZstdArray) -> VortexResult<Canonical> {
537 array.decompress()?.to_canonical()
538 }
539}
540
541impl OperationsVTable<ZstdVTable> for ZstdVTable {
542 fn slice(array: &ZstdArray, start: usize, stop: usize) -> ArrayRef {
543 array._slice(start, stop).into_array()
544 }
545
546 fn scalar_at(array: &ZstdArray, index: usize) -> Scalar {
547 array
548 ._slice(index, index + 1)
549 .decompress()
550 .vortex_expect("zstd decompress")
551 .scalar_at(0)
552 }
553}