vortex_pco/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp;
5use std::fmt::Debug;
6use std::hash::Hash;
7use std::ops::Range;
8
9use pco::data_types::{Number, NumberType};
10use pco::errors::PcoError;
11use pco::wrapped::{ChunkDecompressor, FileCompressor, FileDecompressor};
12use pco::{ChunkConfig, PagingSpec, match_number_enum};
13use vortex_array::arrays::{PrimitiveArray, PrimitiveVTable};
14use vortex_array::compute::filter;
15use vortex_array::stats::{ArrayStats, StatsSetRef};
16use vortex_array::validity::Validity;
17use vortex_array::vtable::{
18    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
19    ValiditySliceHelper, ValidityVTableFromValiditySliceHelper,
20};
21use vortex_array::{
22    ArrayEq, ArrayHash, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, Precision,
23    ToCanonical, vtable,
24};
25use vortex_buffer::{BufferMut, ByteBuffer, ByteBufferMut};
26use vortex_dtype::{DType, PType, half};
27use vortex_error::{VortexError, VortexResult, VortexUnwrap, vortex_err};
28use vortex_scalar::Scalar;
29
30use crate::serde::PcoMetadata;
31use crate::{PcoChunkInfo, PcoPageInfo};
32
33// Overall approach here:
34// Chunk the array into Pco chunks (currently using the default recommended size
35// for good compression), and into finer-grained Pco pages. As we go, write each
36// ChunkMeta as a buffer, followed by each of that chunk's pages as a buffer. We
37// store metadata for each of these "components" (chunk or page). At
38// decompression time, we figure out which components we need to read and only
39// process those. We only compress and decompress valid values.
40
41// Visually, during decompression, we have an interval of pages we're
42// decompressing and a tighter interval of the slice we actually care about.
43// |=============values (all valid elements)==============|
44// |<-n_skipped_values->|----decompressed_values------|
45//                          |----slice_values----|
46//                          ^                    ^
47// |<---slice_value_start-->|<--slice_n_values-->|
// We then insert these values into the correct position using a primitive
// array constructor.
50
// Target number of valid values per Pco chunk; reuses Pco's recommended
// maximum page size as the default chunk size.
const VALUES_PER_CHUNK: usize = pco::DEFAULT_MAX_PAGE_N;

// Generates the PcoVTable / encoding boilerplate for this array type.
vtable!(Pco);
54
/// Wires the Pco encoding into the Vortex VTable framework, declaring which
/// capability vtables this encoding provides. Compute and operator support
/// are explicitly not provided (`NotSupported`).
impl VTable for PcoVTable {
    type Array = PcoArray;
    type Encoding = PcoEncoding;

    type ArrayVTable = Self;
    type CanonicalVTable = Self;
    type OperationsVTable = Self;
    // Validity is derived from the unsliced validity plus the slice bounds
    // (see the ValiditySliceHelper impl below).
    type ValidityVTable = ValidityVTableFromValiditySliceHelper;
    type VisitorVTable = Self;
    type ComputeVTable = NotSupported;
    type EncodeVTable = Self;
    type SerdeVTable = Self;
    type OperatorVTable = NotSupported;

    /// Stable identifier under which this encoding is registered.
    fn id(_encoding: &Self::Encoding) -> EncodingId {
        EncodingId::new_ref("vortex.pco")
    }

    fn encoding(_array: &Self::Array) -> EncodingRef {
        EncodingRef::new_ref(PcoEncoding.as_ref())
    }
}
77
78fn number_type_from_dtype(dtype: &DType) -> NumberType {
79    let ptype = dtype.as_ptype();
80    match ptype {
81        PType::F16 => NumberType::F16,
82        PType::F32 => NumberType::F32,
83        PType::F64 => NumberType::F64,
84        PType::I16 => NumberType::I16,
85        PType::I32 => NumberType::I32,
86        PType::I64 => NumberType::I64,
87        PType::U16 => NumberType::U16,
88        PType::U32 => NumberType::U32,
89        PType::U64 => NumberType::U64,
90        _ => unreachable!("PType not supported by Pco: {:?}", ptype),
91    }
92}
93
94fn collect_valid(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
95    let mask = parray.validity_mask();
96    Ok(filter(&parray.to_array(), &mask)?.to_primitive())
97}
98
99fn vortex_err_from_pco(err: PcoError) -> VortexError {
100    use pco::errors::ErrorKind::*;
101    match err.kind {
102        Io(io_kind) => VortexError::from(std::io::Error::new(io_kind, err.message)),
103        InvalidArgument => vortex_err!(InvalidArgument: "{}", err.message),
104        other => vortex_err!("Pco {:?} error: {}", other, err.message),
105    }
106}
107
/// Marker type for the Pco encoding; referenced by the VTable impl above.
#[derive(Clone, Debug)]
pub struct PcoEncoding;
110
/// A primitive array compressed with Pco.
///
/// Values are stored as a sequence of Pco chunks, each contributing one
/// chunk-metadata buffer and one or more page buffers. Only valid (non-null)
/// elements are compressed; nulls are reconstructed from the validity at
/// decompression time. Slicing is lazy: it only adjusts
/// `slice_start`/`slice_stop`, leaving the compressed data untouched.
#[derive(Clone, Debug)]
pub struct PcoArray {
    // One serialized Pco ChunkMeta per chunk.
    pub(crate) chunk_metas: Vec<ByteBuffer>,
    // All compressed page buffers, in chunk order.
    pub(crate) pages: Vec<ByteBuffer>,
    // Vortex-side description: Pco file header plus per-chunk page layout.
    pub(crate) metadata: PcoMetadata,
    dtype: DType,
    // Validity of the full (unsliced) array.
    pub(crate) unsliced_validity: Validity,
    // Row count of the full (unsliced) array.
    unsliced_n_rows: usize,
    stats_set: ArrayStats,
    // Current logical view: rows [slice_start, slice_stop).
    slice_start: usize,
    slice_stop: usize,
}
123
impl PcoArray {
    /// Assembles a `PcoArray` from already-compressed components.
    ///
    /// `len` and `validity` describe the full (unsliced) array; the new
    /// array's slice initially covers all `len` rows.
    pub fn new(
        chunk_metas: Vec<ByteBuffer>,
        pages: Vec<ByteBuffer>,
        dtype: DType,
        metadata: PcoMetadata,
        len: usize,
        validity: Validity,
    ) -> Self {
        Self {
            chunk_metas,
            pages,
            metadata,
            dtype,
            unsliced_validity: validity,
            unsliced_n_rows: len,
            stats_set: Default::default(),
            slice_start: 0,
            slice_stop: len,
        }
    }

    /// Compresses a primitive array using the default chunk size
    /// (`VALUES_PER_CHUNK`).
    ///
    /// `level` is the Pco compression level; a `values_per_page` of 0 falls
    /// back to one page per chunk.
    pub fn from_primitive(
        parray: &PrimitiveArray,
        level: usize,
        values_per_page: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_with_values_per_chunk(parray, level, VALUES_PER_CHUNK, values_per_page)
    }

    /// Compresses a primitive array, splitting the valid values into Pco
    /// chunks of `values_per_chunk` values and pages of up to
    /// `values_per_page` values each.
    ///
    /// Only valid elements are compressed (see `collect_valid`); the
    /// original validity and length are kept so nulls can be re-inserted on
    /// decompression. Panics (via `number_type_from_dtype`) if the dtype's
    /// ptype is not supported by Pco.
    pub(crate) fn from_primitive_with_values_per_chunk(
        parray: &PrimitiveArray,
        level: usize,
        values_per_chunk: usize,
        values_per_page: usize,
    ) -> VortexResult<Self> {
        let number_type = number_type_from_dtype(parray.dtype());
        // values_per_page == 0 means "one page per chunk".
        let values_per_page = if values_per_page == 0 {
            values_per_chunk
        } else {
            values_per_page
        };

        // perhaps one day we can make this more configurable
        let chunk_config = ChunkConfig::default()
            .with_compression_level(level)
            .with_paging_spec(PagingSpec::EqualPagesUpTo(values_per_page));

        // Compact away nulls: Pco only ever sees valid values.
        let values = collect_valid(parray)?;
        let n_values = values.len();

        let fc = FileCompressor::default();
        let mut header = vec![];
        fc.write_header(&mut header).map_err(vortex_err_from_pco)?;

        let mut chunk_meta_buffers = vec![]; // the Pco component
        let mut chunk_infos = vec![]; // the Vortex metadata
        let mut page_buffers = vec![];
        for chunk_start in (0..n_values).step_by(values_per_chunk) {
            let cc = match_number_enum!(
                number_type,
                NumberType<T> => {
                    let chunk_end = cmp::min(n_values, chunk_start + values_per_chunk);
                    let values = values.buffer::<T>();
                    let chunk = &values.as_slice()[chunk_start..chunk_end];
                    fc
                        .chunk_compressor(chunk, &chunk_config)
                        .map_err(vortex_err_from_pco)?
                }
            );

            // Serialize the chunk's metadata buffer, then each of its pages.
            let mut chunk_meta_buffer = ByteBufferMut::with_capacity(cc.chunk_meta_size_hint());
            cc.write_chunk_meta(&mut chunk_meta_buffer)
                .map_err(vortex_err_from_pco)?;
            chunk_meta_buffers.push(chunk_meta_buffer.freeze());

            let mut page_infos = vec![];
            for (page_idx, page_n_values) in cc.n_per_page().into_iter().enumerate() {
                let mut page = ByteBufferMut::with_capacity(cc.page_size_hint(page_idx));
                cc.write_page(page_idx, &mut page)
                    .map_err(vortex_err_from_pco)?;
                page_buffers.push(page.freeze());
                // Record the per-page value count so decompression can skip
                // pages outside the slice without reading them.
                page_infos.push(PcoPageInfo {
                    n_values: u32::try_from(page_n_values)?,
                });
            }
            chunk_infos.push(PcoChunkInfo { pages: page_infos })
        }

        let metadata = PcoMetadata {
            header,
            chunks: chunk_infos,
        };
        Ok(PcoArray::new(
            chunk_meta_buffers,
            page_buffers,
            parray.dtype().clone(),
            metadata,
            parray.len(),
            parray.validity().clone(),
        ))
    }

    /// Compresses any primitive-encoded array; errors for other encodings.
    pub fn from_array(array: ArrayRef, level: usize, nums_per_page: usize) -> VortexResult<Self> {
        if let Some(parray) = array.as_opt::<PrimitiveVTable>() {
            Self::from_primitive(parray, level, nums_per_page)
        } else {
            Err(vortex_err!("Pco can only encode primitive arrays"))
        }
    }

    /// Decompresses this array's current slice back into a primitive array,
    /// re-inserting nulls according to the (sliced) validity.
    pub fn decompress(&self) -> PrimitiveArray {
        // To start, we figure out which chunks and pages we need to decompress, and with
        // what value offset into the first such page.
        let number_type = number_type_from_dtype(&self.dtype);
        let values_byte_buffer = match_number_enum!(
            number_type,
            NumberType<T> => {
              self.decompress_values_typed::<T>()
            }
        );

        PrimitiveArray::from_values_byte_buffer(
            values_byte_buffer,
            self.dtype.as_ptype(),
            self.unsliced_validity
                .slice(self.slice_start..self.slice_stop),
            self.slice_stop - self.slice_start,
        )
    }

    /// Decompresses exactly the valid values covered by the current slice
    /// (see the page/slice diagram at the top of this file).
    ///
    /// Only pages overlapping the slice's value range are decoded; pages
    /// entirely before it are skipped and counted in `n_skipped_values`,
    /// which is later used to locate the slice within the decoded buffer.
    #[allow(clippy::unwrap_in_result, clippy::unwrap_used)]
    fn decompress_values_typed<T: Number>(&self) -> ByteBuffer {
        // To start, we figure out what range of values we need to decompress.
        let slice_value_indices = self
            .unsliced_validity
            .to_mask(self.unsliced_n_rows)
            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
        let slice_value_start = slice_value_indices[0];
        let slice_value_stop = slice_value_indices[1];
        let slice_n_values = slice_value_stop - slice_value_start;

        // Then we decompress those pages into a buffer. Note that these values
        // may exceed the bounds of the slice, so we need to slice later.
        let (fd, _) = FileDecompressor::new(self.metadata.header.as_slice())
            .map_err(vortex_err_from_pco)
            .vortex_unwrap();
        let mut decompressed_values = BufferMut::<T>::with_capacity(slice_n_values);
        let mut page_idx = 0;
        let mut page_value_start = 0;
        let mut n_skipped_values = 0;
        for (chunk_info, chunk_meta) in self.metadata.chunks.iter().zip(&self.chunk_metas) {
            // Built lazily: a chunk decompressor is only constructed if we
            // actually need one of this chunk's pages.
            let mut cd: Option<ChunkDecompressor<T>> = None;
            for page_info in &chunk_info.pages {
                let page_n_values = page_info.n_values as usize;
                let page_value_stop = page_value_start + page_n_values;

                // This and all later pages in the chunk start after the
                // slice ends; nothing more to decode from this chunk.
                if page_value_start >= slice_value_stop {
                    break;
                }

                if page_value_stop > slice_value_start {
                    // we need this page
                    let old_len = decompressed_values.len();
                    let new_len = old_len + page_n_values;
                    decompressed_values.reserve(page_n_values);
                    // SAFETY: capacity for `page_n_values` additional
                    // elements was just reserved, so `new_len <= capacity`.
                    // The uninitialized tail is written by `pd.decompress`
                    // below before being read — assumes Pco fills the whole
                    // destination slice (TODO confirm against pco docs).
                    unsafe {
                        decompressed_values.set_len(new_len);
                    }
                    let chunk_meta_bytes: &[u8] = chunk_meta.as_ref();
                    let page: &[u8] = self.pages[page_idx].as_ref();
                    if cd.is_none() {
                        let (new_cd, _) = fd
                            .chunk_decompressor(chunk_meta_bytes)
                            .map_err(vortex_err_from_pco)
                            .vortex_unwrap();
                        cd = Some(new_cd);
                    }
                    let mut pd = cd
                        .as_mut()
                        .unwrap()
                        .page_decompressor(page, page_n_values)
                        .map_err(vortex_err_from_pco)
                        .vortex_unwrap();
                    pd.decompress(&mut decompressed_values[old_len..new_len])
                        .map_err(vortex_err_from_pco)
                        .vortex_unwrap();
                } else {
                    // Page lies entirely before the slice: skip it.
                    n_skipped_values += page_n_values;
                }

                page_value_start = page_value_stop;
                page_idx += 1;
            }
        }

        // Slice only the values requested.
        let value_offset = slice_value_start - n_skipped_values;
        decompressed_values
            .freeze()
            .slice(value_offset..value_offset + slice_n_values)
            .into_byte_buffer()
    }

    /// Zero-copy slice: adjusts the slice bounds (given relative to the
    /// current slice) and resets cached stats. Bounds are not validated here.
    pub(crate) fn _slice(&self, start: usize, stop: usize) -> Self {
        PcoArray {
            slice_start: self.slice_start + start,
            slice_stop: self.slice_start + stop,
            stats_set: Default::default(),
            ..self.clone()
        }
    }

    pub(crate) fn dtype(&self) -> &DType {
        &self.dtype
    }

    /// Start of the current slice, in unsliced row coordinates.
    pub(crate) fn slice_start(&self) -> usize {
        self.slice_start
    }

    /// End (exclusive) of the current slice, in unsliced row coordinates.
    pub(crate) fn slice_stop(&self) -> usize {
        self.slice_stop
    }

    /// Row count of the full array, ignoring any slicing.
    pub(crate) fn unsliced_n_rows(&self) -> usize {
        self.unsliced_n_rows
    }
}
353
/// Exposes the unsliced validity together with the current slice bounds so
/// the framework (via `ValidityVTableFromValiditySliceHelper`) can derive
/// this array's validity view.
impl ValiditySliceHelper for PcoArray {
    fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
        (&self.unsliced_validity, self.slice_start, self.slice_stop)
    }
}
359
360impl ArrayVTable<PcoVTable> for PcoVTable {
361    fn len(array: &PcoArray) -> usize {
362        array.slice_stop - array.slice_start
363    }
364
365    fn dtype(array: &PcoArray) -> &DType {
366        &array.dtype
367    }
368
369    fn stats(array: &PcoArray) -> StatsSetRef<'_> {
370        array.stats_set.to_ref(array.as_ref())
371    }
372
373    fn array_hash<H: std::hash::Hasher>(array: &PcoArray, state: &mut H, precision: Precision) {
374        array.dtype.hash(state);
375        array.unsliced_validity.array_hash(state, precision);
376        array.unsliced_n_rows.hash(state);
377        array.slice_start.hash(state);
378        array.slice_stop.hash(state);
379        // Hash chunk_metas and pages using pointer-based hashing
380        for chunk_meta in &array.chunk_metas {
381            chunk_meta.array_hash(state, precision);
382        }
383        for page in &array.pages {
384            page.array_hash(state, precision);
385        }
386    }
387
388    fn array_eq(array: &PcoArray, other: &PcoArray, precision: Precision) -> bool {
389        if array.dtype != other.dtype
390            || !array
391                .unsliced_validity
392                .array_eq(&other.unsliced_validity, precision)
393            || array.unsliced_n_rows != other.unsliced_n_rows
394            || array.slice_start != other.slice_start
395            || array.slice_stop != other.slice_stop
396            || array.chunk_metas.len() != other.chunk_metas.len()
397            || array.pages.len() != other.pages.len()
398        {
399            return false;
400        }
401        for (a, b) in array.chunk_metas.iter().zip(&other.chunk_metas) {
402            if !a.array_eq(b, precision) {
403                return false;
404            }
405        }
406        for (a, b) in array.pages.iter().zip(&other.pages) {
407            if !a.array_eq(b, precision) {
408                return false;
409            }
410        }
411        true
412    }
413}
414
415impl CanonicalVTable<PcoVTable> for PcoVTable {
416    fn canonicalize(array: &PcoArray) -> Canonical {
417        array.decompress().to_canonical()
418    }
419}
420
421impl OperationsVTable<PcoVTable> for PcoVTable {
422    fn slice(array: &PcoArray, range: Range<usize>) -> ArrayRef {
423        array._slice(range.start, range.end).into_array()
424    }
425
426    fn scalar_at(array: &PcoArray, index: usize) -> Scalar {
427        array._slice(index, index + 1).decompress().scalar_at(0)
428    }
429}
430
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_array::{IntoArray, ToCanonical, assert_arrays_eq};
    use vortex_buffer::Buffer;

    use crate::PcoArray;

    /// Round-trips a nullable array through Pco compression, then checks
    /// that slicing the compressed array (both lazily and after
    /// canonicalization) matches slicing the original values.
    #[test]
    fn test_slice_nullable() {
        // Create a nullable array with some nulls
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
            Validity::from_iter([false, true, true, true, true, false]),
        );
        // Level 0, 128 values per page — the whole array fits in one page.
        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
        let decoded = pco.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([
                None,
                Some(20u32),
                Some(30),
                Some(40),
                Some(50),
                None
            ])
        );

        // Slice to get only the non-null values in the middle
        let sliced = pco.slice(1..5);
        let expected =
            PrimitiveArray::from_option_iter([Some(20u32), Some(30), Some(40), Some(50)])
                .into_array();
        assert_arrays_eq!(sliced, expected);
        assert_arrays_eq!(sliced.to_canonical().into_array(), expected);
    }
}
469}