vortex_pco/array.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::cmp;
use std::fmt::Debug;
use std::hash::Hash;
use std::ops::Range;

use pco::ChunkConfig;
use pco::PagingSpec;
use pco::data_types::Number;
use pco::data_types::NumberType;
use pco::errors::PcoError;
use pco::match_number_enum;
use pco::wrapped::ChunkDecompressor;
use pco::wrapped::FileCompressor;
use pco::wrapped::FileDecompressor;
use prost::Message;
use vortex_array::ArrayBufferVisitor;
use vortex_array::ArrayChildVisitor;
use vortex_array::ArrayEq;
use vortex_array::ArrayHash;
use vortex_array::ArrayRef;
use vortex_array::Canonical;
use vortex_array::IntoArray;
use vortex_array::Precision;
use vortex_array::ProstMetadata;
use vortex_array::ToCanonical;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::PrimitiveVTable;
use vortex_array::buffer::BufferHandle;
use vortex_array::compute::filter;
use vortex_array::serde::ArrayChildren;
use vortex_array::stats::ArrayStats;
use vortex_array::stats::StatsSetRef;
use vortex_array::validity::Validity;
use vortex_array::vtable;
use vortex_array::vtable::ArrayId;
use vortex_array::vtable::ArrayVTable;
use vortex_array::vtable::ArrayVTableExt;
use vortex_array::vtable::BaseArrayVTable;
use vortex_array::vtable::CanonicalVTable;
use vortex_array::vtable::EncodeVTable;
use vortex_array::vtable::NotSupported;
use vortex_array::vtable::OperationsVTable;
use vortex_array::vtable::VTable;
use vortex_array::vtable::ValidityHelper;
use vortex_array::vtable::ValiditySliceHelper;
use vortex_array::vtable::ValidityVTableFromValiditySliceHelper;
use vortex_array::vtable::VisitorVTable;
use vortex_buffer::BufferMut;
use vortex_buffer::ByteBuffer;
use vortex_buffer::ByteBufferMut;
use vortex_dtype::DType;
use vortex_dtype::PType;
use vortex_dtype::half;
use vortex_error::VortexError;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;
use vortex_scalar::Scalar;

use crate::PcoChunkInfo;
use crate::PcoMetadata;
use crate::PcoPageInfo;

// Overall approach here:
// Chunk the array into Pco chunks (currently using the default recommended size
// for good compression), and into finer-grained Pco pages. As we go, write each
// ChunkMeta as a buffer, followed by each of that chunk's pages as a buffer. We
// store metadata for each of these "components" (chunk or page). At
// decompression time, we figure out which components we need to read and only
// process those. We only compress and decompress valid values.

// Visually, during decompression, we have an interval of pages we're
// decompressing and a tighter interval of the slice we actually care about.
// |=============values (all valid elements)==============|
// |<-n_skipped_values->|----decompressed_values------|
//                          |----slice_values----|
//                          ^                    ^
// |<---slice_value_start-->|<--slice_n_values-->|
// We then insert these values at the correct positions using a primitive
// array constructor.
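//
// A worked example of the arithmetic above: with pages of 100 values each and
// a slice covering values [250, 420), pages 0 and 1 (values 0..200) fall
// entirely before the slice and are skipped (n_skipped_values = 200), pages 2
// through 4 (values 200..500) are decompressed, and page 5 onward is never
// read. The final step slices decompressed_values[50..220], since
// value_offset = slice_value_start - n_skipped_values = 250 - 200 = 50 and
// slice_n_values = 420 - 250 = 170.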

const VALUES_PER_CHUNK: usize = pco::DEFAULT_MAX_PAGE_N;

vtable!(Pco);

impl VTable for PcoVTable {
    type Array = PcoArray;

    type Metadata = ProstMetadata<PcoMetadata>;

    type ArrayVTable = Self;
    type CanonicalVTable = Self;
    type OperationsVTable = Self;
    type ValidityVTable = ValidityVTableFromValiditySliceHelper;
    type VisitorVTable = Self;
    type ComputeVTable = NotSupported;
    type EncodeVTable = Self;

    fn id(&self) -> ArrayId {
        ArrayId::new_ref("vortex.pco")
    }

    fn encoding(_array: &Self::Array) -> ArrayVTable {
        PcoVTable.as_vtable()
    }

    fn metadata(array: &PcoArray) -> VortexResult<Self::Metadata> {
        Ok(ProstMetadata(array.metadata.clone()))
    }

    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
        Ok(Some(metadata.0.encode_to_vec()))
    }

    fn deserialize(buffer: &[u8]) -> VortexResult<Self::Metadata> {
        Ok(ProstMetadata(PcoMetadata::decode(buffer)?))
    }

    fn build(
        &self,
        dtype: &DType,
        len: usize,
        metadata: &Self::Metadata,
        buffers: &[BufferHandle],
        children: &dyn ArrayChildren,
    ) -> VortexResult<PcoArray> {
        let validity = if children.is_empty() {
            Validity::from(dtype.nullability())
        } else if children.len() == 1 {
            let validity = children.get(0, &Validity::DTYPE, len)?;
            Validity::Array(validity)
        } else {
            vortex_bail!("PcoArray expected 0 or 1 child, got {}", children.len());
        };

        vortex_ensure!(buffers.len() >= metadata.0.chunks.len());
        let chunk_metas = buffers[..metadata.0.chunks.len()]
            .iter()
            .map(|b| b.clone().try_to_bytes())
            .collect::<VortexResult<Vec<_>>>()?;
        let pages = buffers[metadata.0.chunks.len()..]
            .iter()
            .map(|b| b.clone().try_to_bytes())
            .collect::<VortexResult<Vec<_>>>()?;

        let expected_n_pages = metadata
            .0
            .chunks
            .iter()
            .map(|info| info.pages.len())
            .sum::<usize>();
        vortex_ensure!(pages.len() == expected_n_pages);

        Ok(PcoArray::new(
            chunk_metas,
            pages,
            dtype.clone(),
            metadata.0.clone(),
            len,
            validity,
        ))
    }

    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
        vortex_ensure!(
            children.len() <= 1,
            "PcoArray expects 0 or 1 children, got {}",
            children.len()
        );

        if children.is_empty() {
            array.unsliced_validity = Validity::from(array.dtype.nullability());
        } else {
            array.unsliced_validity =
                Validity::Array(children.into_iter().next().vortex_expect("validity child"));
        }

        Ok(())
    }
}

pub(crate) fn number_type_from_dtype(dtype: &DType) -> NumberType {
    let ptype = dtype.as_ptype();
    match ptype {
        PType::F16 => NumberType::F16,
        PType::F32 => NumberType::F32,
        PType::F64 => NumberType::F64,
        PType::I16 => NumberType::I16,
        PType::I32 => NumberType::I32,
        PType::I64 => NumberType::I64,
        PType::U16 => NumberType::U16,
        PType::U32 => NumberType::U32,
        PType::U64 => NumberType::U64,
        _ => unreachable!("PType not supported by Pco: {:?}", ptype),
    }
}

fn collect_valid(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
    let mask = parray.validity_mask();
    Ok(filter(&parray.to_array(), &mask)?.to_primitive())
}

pub(crate) fn vortex_err_from_pco(err: PcoError) -> VortexError {
    use pco::errors::ErrorKind::*;
    match err.kind {
        Io(io_kind) => VortexError::from(std::io::Error::new(io_kind, err.message)),
        InvalidArgument => vortex_err!(InvalidArgument: "{}", err.message),
        other => vortex_err!("Pco {:?} error: {}", other, err.message),
    }
}

#[derive(Debug)]
pub struct PcoVTable;

#[derive(Clone, Debug)]
pub struct PcoArray {
    pub(crate) chunk_metas: Vec<ByteBuffer>,
    pub(crate) pages: Vec<ByteBuffer>,
    pub(crate) metadata: PcoMetadata,
    dtype: DType,
    pub(crate) unsliced_validity: Validity,
    unsliced_n_rows: usize,
    stats_set: ArrayStats,
    slice_start: usize,
    slice_stop: usize,
}

impl PcoArray {
    pub fn new(
        chunk_metas: Vec<ByteBuffer>,
        pages: Vec<ByteBuffer>,
        dtype: DType,
        metadata: PcoMetadata,
        len: usize,
        validity: Validity,
    ) -> Self {
        Self {
            chunk_metas,
            pages,
            metadata,
            dtype,
            unsliced_validity: validity,
            unsliced_n_rows: len,
            stats_set: Default::default(),
            slice_start: 0,
            slice_stop: len,
        }
    }

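    /// Compresses `parray` with Pco at the given compression `level`,
    /// splitting each chunk into pages of up to `values_per_page` values
    /// (0 falls back to one page per chunk). An illustrative round trip,
    /// mirroring the test at the bottom of this file:
    ///
    /// ```ignore
    /// let values = PrimitiveArray::new(
    ///     Buffer::copy_from(vec![10u32, 20, 30, 40]),
    ///     Validity::NonNullable,
    /// );
    /// let pco = PcoArray::from_primitive(&values, 0, 128)?;
    /// let roundtripped = pco.decompress();
    /// ```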
    pub fn from_primitive(
        parray: &PrimitiveArray,
        level: usize,
        values_per_page: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_with_values_per_chunk(parray, level, VALUES_PER_CHUNK, values_per_page)
    }

    pub(crate) fn from_primitive_with_values_per_chunk(
        parray: &PrimitiveArray,
        level: usize,
        values_per_chunk: usize,
        values_per_page: usize,
    ) -> VortexResult<Self> {
        let number_type = number_type_from_dtype(parray.dtype());
        let values_per_page = if values_per_page == 0 {
            values_per_chunk
        } else {
            values_per_page
        };

        // perhaps one day we can make this more configurable
        let chunk_config = ChunkConfig::default()
            .with_compression_level(level)
            .with_paging_spec(PagingSpec::EqualPagesUpTo(values_per_page));

        let values = collect_valid(parray)?;
        let n_values = values.len();

        let fc = FileCompressor::default();
        let mut header = vec![];
        fc.write_header(&mut header).map_err(vortex_err_from_pco)?;

        let mut chunk_meta_buffers = vec![]; // the Pco component
        let mut chunk_infos = vec![]; // the Vortex metadata
        let mut page_buffers = vec![];
        for chunk_start in (0..n_values).step_by(values_per_chunk) {
            let cc = match_number_enum!(
                number_type,
                NumberType<T> => {
                    let chunk_end = cmp::min(n_values, chunk_start + values_per_chunk);
                    let values = values.buffer::<T>();
                    let chunk = &values.as_slice()[chunk_start..chunk_end];
                    fc
                        .chunk_compressor(chunk, &chunk_config)
                        .map_err(vortex_err_from_pco)?
                }
            );

            let mut chunk_meta_buffer = ByteBufferMut::with_capacity(cc.chunk_meta_size_hint());
            cc.write_chunk_meta(&mut chunk_meta_buffer)
                .map_err(vortex_err_from_pco)?;
            chunk_meta_buffers.push(chunk_meta_buffer.freeze());

            let mut page_infos = vec![];
            for (page_idx, page_n_values) in cc.n_per_page().into_iter().enumerate() {
                let mut page = ByteBufferMut::with_capacity(cc.page_size_hint(page_idx));
                cc.write_page(page_idx, &mut page)
                    .map_err(vortex_err_from_pco)?;
                page_buffers.push(page.freeze());
                page_infos.push(PcoPageInfo {
                    n_values: u32::try_from(page_n_values)?,
                });
            }
            chunk_infos.push(PcoChunkInfo { pages: page_infos });
        }

        let metadata = PcoMetadata {
            header,
            chunks: chunk_infos,
        };
        Ok(PcoArray::new(
            chunk_meta_buffers,
            page_buffers,
            parray.dtype().clone(),
            metadata,
            parray.len(),
            parray.validity().clone(),
        ))
    }

    pub fn from_array(array: ArrayRef, level: usize, nums_per_page: usize) -> VortexResult<Self> {
        if let Some(parray) = array.as_opt::<PrimitiveVTable>() {
            Self::from_primitive(parray, level, nums_per_page)
        } else {
            Err(vortex_err!("Pco can only encode primitive arrays"))
        }
    }

    pub fn decompress(&self) -> PrimitiveArray {
        // To start, we figure out which chunks and pages we need to decompress, and with
        // what value offset into the first such page.
        let number_type = number_type_from_dtype(&self.dtype);
        let values_byte_buffer = match_number_enum!(
            number_type,
            NumberType<T> => {
                self.decompress_values_typed::<T>()
            }
        );

        PrimitiveArray::from_values_byte_buffer(
            values_byte_buffer,
            self.dtype.as_ptype(),
            self.unsliced_validity
                .slice(self.slice_start..self.slice_stop),
            self.slice_stop - self.slice_start,
        )
    }

    #[allow(clippy::unwrap_in_result, clippy::unwrap_used)]
    fn decompress_values_typed<T: Number>(&self) -> ByteBuffer {
        // To start, we figure out what range of values we need to decompress.
        let slice_value_indices = self
            .unsliced_validity
            .to_mask(self.unsliced_n_rows)
            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
        let slice_value_start = slice_value_indices[0];
        let slice_value_stop = slice_value_indices[1];
        let slice_n_values = slice_value_stop - slice_value_start;

        // Then we decompress those pages into a buffer. Note that these values
        // may exceed the bounds of the slice, so we need to slice later.
        let (fd, _) = FileDecompressor::new(self.metadata.header.as_slice())
            .map_err(vortex_err_from_pco)
            .vortex_expect("FileDecompressor::new should succeed with valid header");
        let mut decompressed_values = BufferMut::<T>::with_capacity(slice_n_values);
        let mut page_idx = 0;
        let mut page_value_start = 0;
        let mut n_skipped_values = 0;
        for (chunk_info, chunk_meta) in self.metadata.chunks.iter().zip(&self.chunk_metas) {
            let mut cd: Option<ChunkDecompressor<T>> = None;
            for page_info in &chunk_info.pages {
                let page_n_values = page_info.n_values as usize;
                let page_value_stop = page_value_start + page_n_values;

                if page_value_start >= slice_value_stop {
                    break;
                }

                if page_value_stop > slice_value_start {
                    // we need this page
                    let old_len = decompressed_values.len();
                    let new_len = old_len + page_n_values;
                    decompressed_values.reserve(page_n_values);
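                    // SAFETY: `reserve` above guarantees capacity for
                    // `new_len` elements, and the page decompressor below
                    // immediately overwrites the uninitialized tail
                    // `old_len..new_len`.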
                    unsafe {
                        decompressed_values.set_len(new_len);
                    }
                    let chunk_meta_bytes: &[u8] = chunk_meta.as_ref();
                    let page: &[u8] = self.pages[page_idx].as_ref();
                    if cd.is_none() {
                        let (new_cd, _) = fd
                            .chunk_decompressor(chunk_meta_bytes)
                            .map_err(vortex_err_from_pco)
                            .vortex_expect(
                                "chunk_decompressor should succeed with valid chunk metadata",
                            );
                        cd = Some(new_cd);
                    }
                    let mut pd = cd
                        .as_mut()
                        .unwrap()
                        .page_decompressor(page, page_n_values)
                        .map_err(vortex_err_from_pco)
                        .vortex_expect("page_decompressor should succeed with valid page data");
                    pd.decompress(&mut decompressed_values[old_len..new_len])
                        .map_err(vortex_err_from_pco)
                        .vortex_expect("decompress should succeed with valid compressed data");
                } else {
                    n_skipped_values += page_n_values;
                }

                page_value_start = page_value_stop;
                page_idx += 1;
            }
        }

        // Slice only the values requested.
        let value_offset = slice_value_start - n_skipped_values;
        decompressed_values
            .freeze()
            .slice(value_offset..value_offset + slice_n_values)
            .into_byte_buffer()
    }

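    /// Slicing is lazy: it only narrows the logical window into the unsliced
    /// rows (and resets cached stats); no pages are decompressed until the
    /// values are actually needed.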
    pub(crate) fn _slice(&self, start: usize, stop: usize) -> Self {
        PcoArray {
            slice_start: self.slice_start + start,
            slice_stop: self.slice_start + stop,
            stats_set: Default::default(),
            ..self.clone()
        }
    }

    pub(crate) fn dtype(&self) -> &DType {
        &self.dtype
    }

    pub(crate) fn slice_start(&self) -> usize {
        self.slice_start
    }

    pub(crate) fn slice_stop(&self) -> usize {
        self.slice_stop
    }

    pub(crate) fn unsliced_n_rows(&self) -> usize {
        self.unsliced_n_rows
    }
}

impl ValiditySliceHelper for PcoArray {
    fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
        (&self.unsliced_validity, self.slice_start, self.slice_stop)
    }
}

impl BaseArrayVTable<PcoVTable> for PcoVTable {
    fn len(array: &PcoArray) -> usize {
        array.slice_stop - array.slice_start
    }

    fn dtype(array: &PcoArray) -> &DType {
        &array.dtype
    }

    fn stats(array: &PcoArray) -> StatsSetRef<'_> {
        array.stats_set.to_ref(array.as_ref())
    }

    fn array_hash<H: std::hash::Hasher>(array: &PcoArray, state: &mut H, precision: Precision) {
        array.dtype.hash(state);
        array.unsliced_validity.array_hash(state, precision);
        array.unsliced_n_rows.hash(state);
        array.slice_start.hash(state);
        array.slice_stop.hash(state);
        // Hash chunk_metas and pages using pointer-based hashing
        for chunk_meta in &array.chunk_metas {
            chunk_meta.array_hash(state, precision);
        }
        for page in &array.pages {
            page.array_hash(state, precision);
        }
    }

    fn array_eq(array: &PcoArray, other: &PcoArray, precision: Precision) -> bool {
        if array.dtype != other.dtype
            || !array
                .unsliced_validity
                .array_eq(&other.unsliced_validity, precision)
            || array.unsliced_n_rows != other.unsliced_n_rows
            || array.slice_start != other.slice_start
            || array.slice_stop != other.slice_stop
            || array.chunk_metas.len() != other.chunk_metas.len()
            || array.pages.len() != other.pages.len()
        {
            return false;
        }
        for (a, b) in array.chunk_metas.iter().zip(&other.chunk_metas) {
            if !a.array_eq(b, precision) {
                return false;
            }
        }
        for (a, b) in array.pages.iter().zip(&other.pages) {
            if !a.array_eq(b, precision) {
                return false;
            }
        }
        true
    }
}

impl CanonicalVTable<PcoVTable> for PcoVTable {
    fn canonicalize(array: &PcoArray) -> Canonical {
        array.decompress().to_canonical()
    }
}

impl OperationsVTable<PcoVTable> for PcoVTable {
    fn slice(array: &PcoArray, range: Range<usize>) -> ArrayRef {
        array._slice(range.start, range.end).into_array()
    }

    fn scalar_at(array: &PcoArray, index: usize) -> Scalar {
        array._slice(index, index + 1).decompress().scalar_at(0)
    }
}

impl EncodeVTable<PcoVTable> for PcoVTable {
    fn encode(
        _vtable: &PcoVTable,
        canonical: &Canonical,
        _like: Option<&PcoArray>,
    ) -> VortexResult<Option<PcoArray>> {
        let parray = canonical.clone().into_primitive();

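        // Compression level 3; a values_per_page of 0 falls back to one page
        // per chunk (see from_primitive_with_values_per_chunk).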
        Ok(Some(PcoArray::from_primitive(&parray, 3, 0)?))
    }
}

impl VisitorVTable<PcoVTable> for PcoVTable {
    fn visit_buffers(array: &PcoArray, visitor: &mut dyn ArrayBufferVisitor) {
        for buffer in &array.chunk_metas {
            visitor.visit_buffer(buffer);
        }
        for buffer in &array.pages {
            visitor.visit_buffer(buffer);
        }
    }

    fn visit_children(array: &PcoArray, visitor: &mut dyn ArrayChildVisitor) {
        visitor.visit_validity(&array.unsliced_validity, array.unsliced_n_rows());
    }
}

#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::validity::Validity;
    use vortex_buffer::Buffer;

    use crate::PcoArray;

    #[test]
    fn test_slice_nullable() {
        // Create a nullable array with some nulls
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
            Validity::from_iter([false, true, true, true, true, false]),
        );
        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
        let decoded = pco.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([
                None,
                Some(20u32),
                Some(30),
                Some(40),
                Some(50),
                None
            ])
        );

        // Slice to get only the non-null values in the middle
        let sliced = pco.slice(1..5);
        let expected =
            PrimitiveArray::from_option_iter([Some(20u32), Some(30), Some(40), Some(50)])
                .into_array();
        assert_arrays_eq!(sliced, expected);
        assert_arrays_eq!(sliced.to_canonical().into_array(), expected);
    }
}