vortex_alp/alp_rd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use itertools::Itertools;
8use vortex_array::Array;
9use vortex_array::ArrayBufferVisitor;
10use vortex_array::ArrayChildVisitor;
11use vortex_array::ArrayEq;
12use vortex_array::ArrayHash;
13use vortex_array::ArrayRef;
14use vortex_array::Canonical;
15use vortex_array::DeserializeMetadata;
16use vortex_array::Precision;
17use vortex_array::ProstMetadata;
18use vortex_array::SerializeMetadata;
19use vortex_array::ToCanonical;
20use vortex_array::arrays::PrimitiveArray;
21use vortex_array::buffer::BufferHandle;
22use vortex_array::patches::Patches;
23use vortex_array::patches::PatchesMetadata;
24use vortex_array::serde::ArrayChildren;
25use vortex_array::stats::ArrayStats;
26use vortex_array::stats::StatsSetRef;
27use vortex_array::validity::Validity;
28use vortex_array::vtable;
29use vortex_array::vtable::ArrayId;
30use vortex_array::vtable::ArrayVTable;
31use vortex_array::vtable::ArrayVTableExt;
32use vortex_array::vtable::BaseArrayVTable;
33use vortex_array::vtable::CanonicalVTable;
34use vortex_array::vtable::EncodeVTable;
35use vortex_array::vtable::NotSupported;
36use vortex_array::vtable::VTable;
37use vortex_array::vtable::ValidityChild;
38use vortex_array::vtable::ValidityVTableFromChild;
39use vortex_array::vtable::VisitorVTable;
40use vortex_buffer::Buffer;
41use vortex_dtype::DType;
42use vortex_dtype::Nullability;
43use vortex_dtype::PType;
44use vortex_error::VortexError;
45use vortex_error::VortexExpect;
46use vortex_error::VortexResult;
47use vortex_error::vortex_bail;
48use vortex_error::vortex_ensure;
49use vortex_error::vortex_err;
50
51use crate::alp_rd::alp_rd_decode;
52
// Declares the vtable glue for the `ALPRD` encoding via `vortex_array::vtable!`.
// NOTE(review): the exact items this macro generates are not visible in this file; the
// `ALPRDVTable` type itself is declared by hand further down.
vtable!(ALPRD);
54
/// Serialized (prost) metadata for an [`ALPRDArray`].
///
/// It is produced by `VTable::metadata` and consumed by `VTable::build`, alongside the
/// array's children (left parts, right parts and, when `patches` is set, patch indices/values).
#[derive(Clone, prost::Message)]
pub struct ALPRDMetadata {
    /// Bit width of the right (least significant) parts; must fit in a `u8` when rebuilding.
    #[prost(uint32, tag = "1")]
    right_bit_width: u32,
    /// Number of leading entries of `dict` that form the left-parts dictionary.
    #[prost(uint32, tag = "2")]
    dict_len: u32,
    /// Left-parts dictionary entries; each is a `u16` widened to `u32` for serialization.
    #[prost(uint32, repeated, tag = "3")]
    dict: Vec<u32>,
    /// `PType` of the `left_parts` child, stored as its prost enumeration value.
    #[prost(enumeration = "PType", tag = "4")]
    left_parts_ptype: i32,
    /// Metadata describing the optional left-parts patches.
    #[prost(message, tag = "5")]
    patches: Option<PatchesMetadata>,
}
68
69impl VTable for ALPRDVTable {
70    type Array = ALPRDArray;
71
72    type Metadata = ProstMetadata<ALPRDMetadata>;
73
74    type ArrayVTable = Self;
75    type CanonicalVTable = Self;
76    type OperationsVTable = Self;
77    type ValidityVTable = ValidityVTableFromChild;
78    type VisitorVTable = Self;
79    type ComputeVTable = NotSupported;
80    type EncodeVTable = Self;
81
82    fn id(&self) -> ArrayId {
83        ArrayId::new_ref("vortex.alprd")
84    }
85
86    fn encoding(_array: &Self::Array) -> ArrayVTable {
87        ALPRDVTable.as_vtable()
88    }
89
90    fn metadata(array: &ALPRDArray) -> VortexResult<Self::Metadata> {
91        let dict = array
92            .left_parts_dictionary()
93            .iter()
94            .map(|&i| i as u32)
95            .collect::<Vec<_>>();
96
97        Ok(ProstMetadata(ALPRDMetadata {
98            right_bit_width: array.right_bit_width() as u32,
99            dict_len: array.left_parts_dictionary().len() as u32,
100            dict,
101            left_parts_ptype: PType::try_from(array.left_parts().dtype())
102                .vortex_expect("Must be a valid PType") as i32,
103            patches: array
104                .left_parts_patches()
105                .map(|p| p.to_metadata(array.len(), array.left_parts().dtype()))
106                .transpose()?,
107        }))
108    }
109
110    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
111        Ok(Some(metadata.serialize()))
112    }
113
114    fn deserialize(buffer: &[u8]) -> VortexResult<Self::Metadata> {
115        Ok(ProstMetadata(
116            <ProstMetadata<ALPRDMetadata> as DeserializeMetadata>::deserialize(buffer)?,
117        ))
118    }
119
120    fn build(
121        &self,
122        dtype: &DType,
123        len: usize,
124        metadata: &Self::Metadata,
125        _buffers: &[BufferHandle],
126        children: &dyn ArrayChildren,
127    ) -> VortexResult<ALPRDArray> {
128        if children.len() < 2 {
129            vortex_bail!(
130                "Expected at least 2 children for ALPRD encoding, found {}",
131                children.len()
132            );
133        }
134
135        let left_parts_dtype = DType::Primitive(metadata.0.left_parts_ptype(), dtype.nullability());
136        let left_parts = children.get(0, &left_parts_dtype, len)?;
137        let left_parts_dictionary: Buffer<u16> = metadata.0.dict.as_slice()
138            [0..metadata.0.dict_len as usize]
139            .iter()
140            .map(|&i| {
141                u16::try_from(i)
142                    .map_err(|_| vortex_err!("left_parts_dictionary code {i} does not fit in u16"))
143            })
144            .try_collect()?;
145
146        let right_parts_dtype = match &dtype {
147            DType::Primitive(PType::F32, _) => {
148                DType::Primitive(PType::U32, Nullability::NonNullable)
149            }
150            DType::Primitive(PType::F64, _) => {
151                DType::Primitive(PType::U64, Nullability::NonNullable)
152            }
153            _ => vortex_bail!("Expected f32 or f64 dtype, got {:?}", dtype),
154        };
155        let right_parts = children.get(1, &right_parts_dtype, len)?;
156
157        let left_parts_patches = metadata
158            .0
159            .patches
160            .map(|p| {
161                let indices = children.get(2, &p.indices_dtype(), p.len())?;
162                let values = children.get(3, &left_parts_dtype, p.len())?;
163
164                Ok::<_, VortexError>(Patches::new(
165                    len,
166                    p.offset(),
167                    indices,
168                    values,
169                    // TODO(0ax1): handle chunk offsets
170                    None,
171                ))
172            })
173            .transpose()?;
174
175        ALPRDArray::try_new(
176            dtype.clone(),
177            left_parts,
178            left_parts_dictionary,
179            right_parts,
180            u8::try_from(metadata.0.right_bit_width).map_err(|_| {
181                vortex_err!(
182                    "right_bit_width {} out of u8 range",
183                    metadata.0.right_bit_width
184                )
185            })?,
186            left_parts_patches,
187        )
188    }
189
190    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
191        // Children: left_parts, right_parts, patches (if present): indices, values
192        let patches_info = array
193            .left_parts_patches
194            .as_ref()
195            .map(|p| (p.array_len(), p.offset()));
196
197        let expected_children = if patches_info.is_some() { 4 } else { 2 };
198
199        vortex_ensure!(
200            children.len() == expected_children,
201            "ALPRDArray expects {} children, got {}",
202            expected_children,
203            children.len()
204        );
205
206        let mut children_iter = children.into_iter();
207        array.left_parts = children_iter
208            .next()
209            .ok_or_else(|| vortex_err!("Expected left_parts child"))?;
210        array.right_parts = children_iter
211            .next()
212            .ok_or_else(|| vortex_err!("Expected right_parts child"))?;
213
214        if let Some((array_len, offset)) = patches_info {
215            let indices = children_iter
216                .next()
217                .ok_or_else(|| vortex_err!("Expected patch indices child"))?;
218            let values = children_iter
219                .next()
220                .ok_or_else(|| vortex_err!("Expected patch values child"))?;
221
222            array.left_parts_patches = Some(Patches::new(
223                array_len, offset, indices, values,
224                None, // chunk_offsets not currently supported for ALPRD
225            ));
226        }
227
228        Ok(())
229    }
230}
231
/// A floating-point array stored as dictionary-encoded left (most significant) bits plus raw
/// right (least significant) bits.
#[derive(Clone, Debug)]
pub struct ALPRDArray {
    /// Logical float dtype of the decoded values.
    dtype: DType,
    /// Codes for the most significant bits, resolved through `left_parts_dictionary`.
    /// This child also carries the array's validity.
    left_parts: ArrayRef,
    /// Optional patches for the left parts, passed to the decoder during canonicalization.
    left_parts_patches: Option<Patches>,
    /// Dictionary mapping `left_parts` codes to left-bit patterns.
    left_parts_dictionary: Buffer<u16>,
    /// The least significant bits of each value, as a non-nullable unsigned integer array.
    right_parts: ArrayRef,
    /// Bit width of the right parts, as passed to `alp_rd_decode`.
    right_bit_width: u8,
    /// Cached statistics for this array.
    stats_set: ArrayStats,
}
242
/// Vtable type for [`ALPRDArray`], registered under the encoding id `vortex.alprd`.
#[derive(Debug)]
pub struct ALPRDVTable;
245
246impl ALPRDArray {
247    /// Build a new `ALPRDArray` from components.
248    pub fn try_new(
249        dtype: DType,
250        left_parts: ArrayRef,
251        left_parts_dictionary: Buffer<u16>,
252        right_parts: ArrayRef,
253        right_bit_width: u8,
254        left_parts_patches: Option<Patches>,
255    ) -> VortexResult<Self> {
256        if !dtype.is_float() {
257            vortex_bail!("ALPRDArray given invalid DType ({dtype})");
258        }
259
260        let len = left_parts.len();
261        if right_parts.len() != len {
262            vortex_bail!(
263                "left_parts (len {}) and right_parts (len {}) must be of same length",
264                len,
265                right_parts.len()
266            );
267        }
268
269        if !left_parts.dtype().is_unsigned_int() {
270            vortex_bail!("left_parts dtype must be uint");
271        }
272        // we delegate array validity to the left_parts child
273        if dtype.is_nullable() != left_parts.dtype().is_nullable() {
274            vortex_bail!(
275                "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
276                dtype,
277                left_parts.dtype()
278            );
279        }
280
281        // we enforce right_parts to be non-nullable uint
282        if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
283            vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
284        }
285
286        let left_parts_patches = left_parts_patches
287            .map(|patches| {
288                if !patches.values().all_valid() {
289                    vortex_bail!("patches must be all valid: {}", patches.values());
290                }
291                // TODO(ngates): assert the DType, don't cast it.
292                patches.cast_values(left_parts.dtype())
293            })
294            .transpose()?;
295
296        Ok(Self {
297            dtype,
298            left_parts,
299            left_parts_dictionary,
300            right_parts,
301            right_bit_width,
302            left_parts_patches,
303            stats_set: Default::default(),
304        })
305    }
306
307    /// Build a new `ALPRDArray` from components. This does not perform any validation, and instead
308    /// it constructs it from parts.
309    pub(crate) unsafe fn new_unchecked(
310        dtype: DType,
311        left_parts: ArrayRef,
312        left_parts_dictionary: Buffer<u16>,
313        right_parts: ArrayRef,
314        right_bit_width: u8,
315        left_parts_patches: Option<Patches>,
316    ) -> Self {
317        Self {
318            dtype,
319            left_parts,
320            left_parts_patches,
321            left_parts_dictionary,
322            right_parts,
323            right_bit_width,
324            stats_set: Default::default(),
325        }
326    }
327
328    /// Returns true if logical type of the array values is f32.
329    ///
330    /// Returns false if the logical type of the array values is f64.
331    #[inline]
332    pub fn is_f32(&self) -> bool {
333        matches!(&self.dtype, DType::Primitive(PType::F32, _))
334    }
335
336    /// The leftmost (most significant) bits of the floating point values stored in the array.
337    ///
338    /// These are bit-packed and dictionary encoded, and cannot directly be interpreted without
339    /// the metadata of this array.
340    pub fn left_parts(&self) -> &ArrayRef {
341        &self.left_parts
342    }
343
344    /// The rightmost (least significant) bits of the floating point values stored in the array.
345    pub fn right_parts(&self) -> &ArrayRef {
346        &self.right_parts
347    }
348
349    #[inline]
350    pub fn right_bit_width(&self) -> u8 {
351        self.right_bit_width
352    }
353
354    /// Patches of left-most bits.
355    pub fn left_parts_patches(&self) -> Option<&Patches> {
356        self.left_parts_patches.as_ref()
357    }
358
359    /// The dictionary that maps the codes in `left_parts` into bit patterns.
360    #[inline]
361    pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
362        &self.left_parts_dictionary
363    }
364
365    pub fn replace_left_parts_patches(&mut self, patches: Option<Patches>) {
366        self.left_parts_patches = patches;
367    }
368}
369
370impl ValidityChild<ALPRDVTable> for ALPRDVTable {
371    fn validity_child(array: &ALPRDArray) -> &ArrayRef {
372        array.left_parts()
373    }
374}
375
376impl BaseArrayVTable<ALPRDVTable> for ALPRDVTable {
377    fn len(array: &ALPRDArray) -> usize {
378        array.left_parts.len()
379    }
380
381    fn dtype(array: &ALPRDArray) -> &DType {
382        &array.dtype
383    }
384
385    fn stats(array: &ALPRDArray) -> StatsSetRef<'_> {
386        array.stats_set.to_ref(array.as_ref())
387    }
388
389    fn array_hash<H: std::hash::Hasher>(array: &ALPRDArray, state: &mut H, precision: Precision) {
390        array.dtype.hash(state);
391        array.left_parts.array_hash(state, precision);
392        array.left_parts_dictionary.array_hash(state, precision);
393        array.right_parts.array_hash(state, precision);
394        array.right_bit_width.hash(state);
395        array.left_parts_patches.array_hash(state, precision);
396    }
397
398    fn array_eq(array: &ALPRDArray, other: &ALPRDArray, precision: Precision) -> bool {
399        array.dtype == other.dtype
400            && array.left_parts.array_eq(&other.left_parts, precision)
401            && array
402                .left_parts_dictionary
403                .array_eq(&other.left_parts_dictionary, precision)
404            && array.right_parts.array_eq(&other.right_parts, precision)
405            && array.right_bit_width == other.right_bit_width
406            && array
407                .left_parts_patches
408                .array_eq(&other.left_parts_patches, precision)
409    }
410}
411
412impl CanonicalVTable<ALPRDVTable> for ALPRDVTable {
413    fn canonicalize(array: &ALPRDArray) -> Canonical {
414        let left_parts = array.left_parts().to_primitive();
415        let right_parts = array.right_parts().to_primitive();
416
417        // Decode the left_parts using our builtin dictionary.
418        let left_parts_dict = array.left_parts_dictionary();
419
420        let decoded_array = if array.is_f32() {
421            PrimitiveArray::new(
422                alp_rd_decode::<f32>(
423                    left_parts.into_buffer::<u16>(),
424                    left_parts_dict,
425                    array.right_bit_width,
426                    right_parts.into_buffer_mut::<u32>(),
427                    array.left_parts_patches(),
428                ),
429                Validity::copy_from_array(array.as_ref()),
430            )
431        } else {
432            PrimitiveArray::new(
433                alp_rd_decode::<f64>(
434                    left_parts.into_buffer::<u16>(),
435                    left_parts_dict,
436                    array.right_bit_width,
437                    right_parts.into_buffer_mut::<u64>(),
438                    array.left_parts_patches(),
439                ),
440                Validity::copy_from_array(array.as_ref()),
441            )
442        };
443
444        Canonical::Primitive(decoded_array)
445    }
446}
447
448impl EncodeVTable<ALPRDVTable> for ALPRDVTable {
449    fn encode(
450        _vtable: &ALPRDVTable,
451        canonical: &Canonical,
452        like: Option<&ALPRDArray>,
453    ) -> VortexResult<Option<ALPRDArray>> {
454        let parray = canonical.clone().into_primitive();
455
456        let alprd_array = match like {
457            None => {
458                let encoder = match parray.ptype() {
459                    PType::F32 => crate::alp_rd::RDEncoder::new(parray.as_slice::<f32>()),
460                    PType::F64 => crate::alp_rd::RDEncoder::new(parray.as_slice::<f64>()),
461                    ptype => vortex_bail!("cannot ALPRD compress ptype {ptype}"),
462                };
463                encoder.encode(&parray)
464            }
465            Some(like) => {
466                let encoder = crate::alp_rd::RDEncoder::from_parts(
467                    like.right_bit_width(),
468                    like.left_parts_dictionary().to_vec(),
469                );
470                encoder.encode(&parray)
471            }
472        };
473
474        Ok(Some(alprd_array))
475    }
476}
477
478impl VisitorVTable<ALPRDVTable> for ALPRDVTable {
479    fn visit_buffers(_array: &ALPRDArray, _visitor: &mut dyn ArrayBufferVisitor) {}
480
481    fn visit_children(array: &ALPRDArray, visitor: &mut dyn ArrayChildVisitor) {
482        visitor.visit_child("left_parts", array.left_parts());
483        visitor.visit_child("right_parts", array.right_parts());
484        if let Some(patches) = array.left_parts_patches() {
485            visitor.visit_patches(patches);
486        }
487    }
488}
489
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ProstMetadata;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::patches::PatchesMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use super::ALPRDMetadata;
    use crate::ALPRDFloat;
    use crate::alp_rd;

    /// Positions that are replaced with nulls before encoding.
    const NULL_POSITIONS: [usize; 3] = [1, 5, 900];

    /// Round-trips a nullable float array through an encoder whose dictionary was trained on a
    /// single unrelated seed value, then checks the decoded output matches the input.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");

        // Null out a handful of positions, keep everything else.
        let expected: Vec<Option<T>> = reals
            .into_iter()
            .enumerate()
            .map(|(idx, value)| (!NULL_POSITIONS.contains(&idx)).then_some(value))
            .collect();

        let input = PrimitiveArray::from_option_iter(expected.iter().cloned());

        // Pick a seed that we know will trigger lots of patches.
        let encoder: alp_rd::RDEncoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);
        let decoded = encoder.encode(&input).to_primitive();

        assert_arrays_eq!(decoded, PrimitiveArray::from_option_iter(expected));
    }

    /// Pins the serialized metadata layout against the checked-in golden file.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_alprd_metadata() {
        let metadata = ALPRDMetadata {
            right_bit_width: u32::MAX,
            dict_len: 8,
            dict: Vec::new(),
            left_parts_ptype: PType::U64 as i32,
            patches: Some(PatchesMetadata::new(
                usize::MAX,
                usize::MAX,
                PType::U64,
                None,
                None,
                None,
            )),
        };

        check_metadata("alprd.metadata", ProstMetadata(metadata));
    }
}