1use std::fmt::Debug;
5use std::ops::Range;
6use std::sync::Arc;
7
8use vortex_array::accessor::ArrayAccessor;
9use vortex_array::arrays::{BinaryView, PrimitiveArray, VarBinViewArray};
10use vortex_array::compute::filter;
11use vortex_array::stats::{ArrayStats, StatsSetRef};
12use vortex_array::validity::Validity;
13use vortex_array::vtable::{
14 ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
15 ValiditySliceHelper, ValidityVTableFromValiditySliceHelper,
16};
17use vortex_array::{ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
18use vortex_buffer::{Alignment, Buffer, BufferMut, ByteBuffer, ByteBufferMut};
19use vortex_dtype::DType;
20use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_err, vortex_panic};
21use vortex_mask::AllOr;
22use vortex_scalar::Scalar;
23
24use crate::serde::{ZstdFrameMetadata, ZstdMetadata};
25
/// Minimum number of frames (used as dictionary training samples) required before a shared Zstd
/// dictionary is trained. With fewer frames, values are compressed without a dictionary.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
/// Integer type of the little-endian length prefix written in front of every variable-length
/// value before compression, and read back when views are reconstructed.
type ViewLen = u32;

// NOTE(review): presumably generates the `ZstdVTable` glue for this encoding — confirm against
// the `vtable!` macro definition.
vtable!(Zstd);
49
50impl VTable for ZstdVTable {
51 type Array = ZstdArray;
52 type Encoding = ZstdEncoding;
53
54 type ArrayVTable = Self;
55 type CanonicalVTable = Self;
56 type OperationsVTable = Self;
57 type ValidityVTable = ValidityVTableFromValiditySliceHelper;
58 type VisitorVTable = Self;
59 type ComputeVTable = NotSupported;
60 type EncodeVTable = Self;
61 type SerdeVTable = Self;
62 type PipelineVTable = NotSupported;
63
64 fn id(_encoding: &Self::Encoding) -> EncodingId {
65 EncodingId::new_ref("vortex.zstd")
66 }
67
68 fn encoding(_array: &Self::Array) -> EncodingRef {
69 EncodingRef::new_ref(ZstdEncoding.as_ref())
70 }
71}
72
/// Stateless marker type for the Zstd encoding, identified as `"vortex.zstd"`.
#[derive(Clone, Debug)]
pub struct ZstdEncoding;
75
/// An array whose valid (non-null) values are stored as one or more Zstd-compressed frames.
///
/// Slicing does not touch the compressed data: a slice is the same set of frames plus a
/// narrower `slice_start..slice_stop` row range.
#[derive(Clone, Debug)]
pub struct ZstdArray {
    /// Dictionary shared by all frames, if one was trained; handed to the decompressor.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// Compressed frames, in value order.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Dictionary size plus per-frame uncompressed size and value count.
    pub(crate) metadata: ZstdMetadata,
    /// Logical type of the decompressed array.
    dtype: DType,
    /// Validity of all `unsliced_n_rows` rows, before any slicing is applied.
    pub(crate) unsliced_validity: Validity,
    /// Number of rows (nulls included) before any slicing is applied.
    unsliced_n_rows: usize,
    /// Statistics for this array; reset to empty whenever a slice is taken.
    stats_set: ArrayStats,
    /// First row of the current slice, as an index into the unsliced rows.
    slice_start: usize,
    /// One past the last row of the current slice, as an index into the unsliced rows.
    slice_stop: usize,
}
88
/// Output of compressing a value buffer: the optional trained dictionary, the compressed
/// frames, and the metadata entry recorded for each frame (same order as `frames`).
struct Frames {
    /// Trained dictionary, or `None` when compression ran without one.
    dictionary: Option<ByteBuffer>,
    /// Compressed bytes of each frame.
    frames: Vec<ByteBuffer>,
    /// Uncompressed size and value count of each frame.
    frame_metas: Vec<ZstdFrameMetadata>,
}
94
/// Picks the maximum dictionary size to train for `uncompressed_size` bytes of input.
///
/// The result is 1% of the input size, bounded below by 256 bytes and above by 100 KiB.
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const MIN_DICT_BYTES: usize = 256;
    const MAX_DICT_BYTES: usize = 100 * 1024;

    let one_percent = uncompressed_size / 100;
    one_percent.max(MIN_DICT_BYTES).min(MAX_DICT_BYTES)
}
102
103fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
104 let mask = parray.validity_mask();
105 Ok(filter(&parray.to_array(), &mask)?.to_primitive())
106}
107
108fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
109 let mask = vbv.validity_mask();
110 let buffer_and_value_byte_indices = match mask.boolean_buffer() {
111 AllOr::None => (Buffer::empty(), Vec::new()),
112 _ => {
113 let mut buffer = BufferMut::with_capacity(
114 usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
115 + mask.true_count() * size_of::<ViewLen>(),
116 );
117 let mut value_byte_indices = Vec::new();
118 vbv.with_iterator(|iterator| {
119 for value in iterator.flatten() {
121 value_byte_indices.push(buffer.len());
122 buffer
124 .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
125 buffer.extend_from_slice(value);
126 }
127 Ok::<_, VortexError>(())
128 })??;
129 (buffer.freeze(), value_byte_indices)
130 }
131 };
132 Ok(buffer_and_value_byte_indices)
133}
134
135fn reconstruct_views(buffer: ByteBuffer) -> Buffer<BinaryView> {
136 let mut res = BufferMut::<BinaryView>::empty();
137 let mut offset = 0;
138 while offset < buffer.len() {
139 let str_len = ViewLen::from_le_bytes(
140 buffer
141 .get(offset..offset + size_of::<ViewLen>())
142 .vortex_expect("corrupted zstd length")
143 .try_into()
144 .vortex_expect("must fit ViewLen size"),
145 ) as usize;
146 offset += size_of::<ViewLen>();
147 let value = &buffer[offset..offset + str_len];
148 res.push(BinaryView::make_view(
149 value,
150 0,
151 u32::try_from(offset).vortex_expect("offset must fit in u32"),
152 ));
153 offset += str_len;
154 }
155 res.freeze()
156}
157
158impl ZstdArray {
159 pub fn new(
160 dictionary: Option<ByteBuffer>,
161 frames: Vec<ByteBuffer>,
162 dtype: DType,
163 metadata: ZstdMetadata,
164 n_rows: usize,
165 validity: Validity,
166 ) -> Self {
167 Self {
168 dictionary,
169 frames,
170 metadata,
171 dtype,
172 unsliced_validity: validity,
173 unsliced_n_rows: n_rows,
174 stats_set: Default::default(),
175 slice_start: 0,
176 slice_stop: n_rows,
177 }
178 }
179
180 fn compress_values(
181 value_bytes: &ByteBuffer,
182 frame_byte_starts: &[usize],
183 level: i32,
184 values_per_frame: usize,
185 n_values: usize,
186 ) -> VortexResult<Frames> {
187 let n_frames = frame_byte_starts.len();
188
189 let mut sample_sizes = Vec::with_capacity(n_frames);
191 for i in 0..n_frames {
192 let frame_byte_end = frame_byte_starts
193 .get(i + 1)
194 .copied()
195 .unwrap_or(value_bytes.len());
196 sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
197 }
198 debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
199
200 let (dictionary, mut compressor) = if sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY {
201 (None, zstd::bulk::Compressor::new(level)?)
203 } else {
204 let max_dict_size = choose_max_dict_size(value_bytes.len());
206 let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
207 .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
208
209 let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
210 (Some(ByteBuffer::from(dict)), compressor)
211 };
212
213 let mut frame_metas = vec![];
214 let mut frames = vec![];
215 for i in 0..n_frames {
216 let frame_byte_end = frame_byte_starts
217 .get(i + 1)
218 .copied()
219 .unwrap_or(value_bytes.len());
220 let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
221 let compressed = compressor
222 .compress(uncompressed)
223 .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
224 frame_metas.push(ZstdFrameMetadata {
225 uncompressed_size: uncompressed.len() as u64,
226 n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
227 });
228 frames.push(ByteBuffer::from(compressed));
229 }
230
231 Ok(Frames {
232 dictionary,
233 frames,
234 frame_metas,
235 })
236 }
237
238 pub fn from_primitive(
239 parray: &PrimitiveArray,
240 level: i32,
241 values_per_frame: usize,
242 ) -> VortexResult<Self> {
243 let dtype = parray.dtype().clone();
244 let byte_width = parray.ptype().byte_width();
245
246 let values = collect_valid_primitive(parray)?;
248 let n_values = values.len();
249 let values_per_frame = if values_per_frame > 0 {
250 values_per_frame
251 } else {
252 n_values
253 };
254
255 let value_bytes = values.byte_buffer();
256 let frame_byte_starts = (0..n_values * byte_width)
257 .step_by(values_per_frame * byte_width)
258 .collect::<Vec<_>>();
259 let Frames {
260 dictionary,
261 frames,
262 frame_metas,
263 } = Self::compress_values(
264 value_bytes,
265 &frame_byte_starts,
266 level,
267 values_per_frame,
268 n_values,
269 )?;
270
271 let metadata = ZstdMetadata {
272 dictionary_size: dictionary
273 .as_ref()
274 .map_or(0, |dict| dict.len())
275 .try_into()?,
276 frames: frame_metas,
277 };
278
279 Ok(ZstdArray::new(
280 dictionary,
281 frames,
282 dtype,
283 metadata,
284 parray.len(),
285 parray.validity().clone(),
286 ))
287 }
288
289 pub fn from_var_bin_view(
290 vbv: &VarBinViewArray,
291 level: i32,
292 values_per_frame: usize,
293 ) -> VortexResult<Self> {
294 let dtype = vbv.dtype().clone();
300
301 let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
303 let n_values = value_byte_indices.len();
304 let values_per_frame = if values_per_frame > 0 {
305 values_per_frame
306 } else {
307 n_values
308 };
309
310 let frame_byte_starts = (0..n_values)
311 .step_by(values_per_frame)
312 .map(|i| value_byte_indices[i])
313 .collect::<Vec<_>>();
314 let Frames {
315 dictionary,
316 frames,
317 frame_metas,
318 } = Self::compress_values(
319 &value_bytes,
320 &frame_byte_starts,
321 level,
322 values_per_frame,
323 n_values,
324 )?;
325
326 let metadata = ZstdMetadata {
327 dictionary_size: dictionary
328 .as_ref()
329 .map_or(0, |dict| dict.len())
330 .try_into()?,
331 frames: frame_metas,
332 };
333 Ok(ZstdArray::new(
334 dictionary,
335 frames,
336 dtype,
337 metadata,
338 vbv.len(),
339 vbv.validity().clone(),
340 ))
341 }
342
343 pub fn from_canonical(
344 canonical: &Canonical,
345 level: i32,
346 values_per_frame: usize,
347 ) -> VortexResult<Option<Self>> {
348 match canonical {
349 Canonical::Primitive(parray) => Ok(Some(ZstdArray::from_primitive(
350 parray,
351 level,
352 values_per_frame,
353 )?)),
354 Canonical::VarBinView(vbv) => Ok(Some(ZstdArray::from_var_bin_view(
355 vbv,
356 level,
357 values_per_frame,
358 )?)),
359 _ => Ok(None),
360 }
361 }
362
363 pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
364 Self::from_canonical(&array.to_canonical(), level, values_per_frame)?
365 .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
366 }
367
368 fn byte_width(&self) -> usize {
369 if self.dtype.is_primitive() {
370 self.dtype.as_ptype().byte_width()
371 } else {
372 1
373 }
374 }
375
376 pub fn decompress(&self) -> ArrayRef {
377 let byte_width = self.byte_width();
380 let slice_n_rows = self.slice_stop - self.slice_start;
381 let slice_value_indices = self
382 .unsliced_validity
383 .to_mask(self.unsliced_n_rows)
384 .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
385
386 let slice_value_idx_start = slice_value_indices[0];
387 let slice_value_idx_stop = slice_value_indices[1];
388
389 let mut frames_to_decompress = vec![];
390 let mut value_idx_start = 0;
391 let mut uncompressed_size_to_decompress = 0;
392 let mut n_skipped_values = 0;
393 for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
394 if value_idx_start >= slice_value_idx_stop {
395 break;
396 }
397
398 let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
399 .vortex_expect("Uncompressed size must fit in usize");
400 let frame_n_values = if frame_meta.n_values == 0 {
401 frame_uncompressed_size / byte_width
403 } else {
404 usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
405 };
406
407 let value_idx_stop = value_idx_start + frame_n_values;
408 if value_idx_stop > slice_value_idx_start {
409 frames_to_decompress.push(frame);
411 uncompressed_size_to_decompress += frame_uncompressed_size;
412 } else {
413 n_skipped_values += frame_n_values;
414 }
415 value_idx_start = value_idx_stop;
416 }
417
418 let mut decompressor = if let Some(dictionary) = &self.dictionary {
420 zstd::bulk::Decompressor::with_dictionary(dictionary)
421 } else {
422 zstd::bulk::Decompressor::new()
423 }
424 .vortex_expect("Decompressor encountered io error");
425 let mut decompressed = ByteBufferMut::with_capacity_aligned(
426 uncompressed_size_to_decompress,
427 Alignment::new(byte_width),
428 );
429 unsafe {
430 decompressed.set_len(uncompressed_size_to_decompress);
433 }
434 let mut uncompressed_start = 0;
435 for frame in frames_to_decompress {
436 let uncompressed_written = decompressor
437 .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])
438 .vortex_expect("error while decompressing zstd array");
439 uncompressed_start += uncompressed_written;
440 }
441 if uncompressed_start != uncompressed_size_to_decompress {
442 vortex_panic!(
443 "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
444 uncompressed_size_to_decompress,
445 uncompressed_start
446 );
447 }
448
449 let decompressed = decompressed.freeze();
450 let slice_validity = self
452 .unsliced_validity
453 .slice(self.slice_start..self.slice_stop);
454
455 match &self.dtype {
456 DType::Primitive(..) => {
457 let slice_values_buffer = decompressed.slice(
458 (slice_value_idx_start - n_skipped_values) * byte_width
459 ..(slice_value_idx_stop - n_skipped_values) * byte_width,
460 );
461 let primitive = PrimitiveArray::from_values_byte_buffer(
462 slice_values_buffer,
463 self.dtype.as_ptype(),
464 slice_validity,
465 slice_n_rows,
466 );
467
468 primitive.into_array()
469 }
470 DType::Binary(_) | DType::Utf8(_) => {
471 let views = reconstruct_views(decompressed.clone()).slice(
475 slice_value_idx_start - n_skipped_values
476 ..slice_value_idx_stop - n_skipped_values,
477 );
478
479 let vbv = unsafe {
481 VarBinViewArray::new_unchecked(
482 views,
483 Arc::from([decompressed]),
484 self.dtype.clone(),
485 slice_validity,
486 )
487 };
488 vbv.into_array()
489 }
490 _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype),
491 }
492 }
493
494 pub(crate) fn _slice(&self, start: usize, stop: usize) -> ZstdArray {
495 ZstdArray {
496 slice_start: self.slice_start + start,
497 slice_stop: self.slice_start + stop,
498 stats_set: Default::default(),
499 ..self.clone()
500 }
501 }
502
503 pub(crate) fn dtype(&self) -> &DType {
504 &self.dtype
505 }
506
507 pub(crate) fn slice_start(&self) -> usize {
508 self.slice_start
509 }
510
511 pub(crate) fn slice_stop(&self) -> usize {
512 self.slice_stop
513 }
514
515 pub(crate) fn unsliced_n_rows(&self) -> usize {
516 self.unsliced_n_rows
517 }
518}
519
520impl ValiditySliceHelper for ZstdArray {
521 fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
522 (&self.unsliced_validity, self.slice_start, self.slice_stop)
523 }
524}
525
526impl ArrayVTable<ZstdVTable> for ZstdVTable {
527 fn len(array: &ZstdArray) -> usize {
528 array.slice_stop - array.slice_start
529 }
530
531 fn dtype(array: &ZstdArray) -> &DType {
532 &array.dtype
533 }
534
535 fn stats(array: &ZstdArray) -> StatsSetRef<'_> {
536 array.stats_set.to_ref(array.as_ref())
537 }
538}
539
540impl CanonicalVTable<ZstdVTable> for ZstdVTable {
541 fn canonicalize(array: &ZstdArray) -> Canonical {
542 array.decompress().to_canonical()
543 }
544}
545
546impl OperationsVTable<ZstdVTable> for ZstdVTable {
547 fn slice(array: &ZstdArray, range: Range<usize>) -> ArrayRef {
548 array._slice(range.start, range.end).into_array()
549 }
550
551 fn scalar_at(array: &ZstdArray, index: usize) -> Scalar {
552 array._slice(index, index + 1).decompress().scalar_at(0)
553 }
554}