Skip to main content

vortex_alp/alp_rd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use itertools::Itertools;
8use vortex_array::Array;
9use vortex_array::ArrayBufferVisitor;
10use vortex_array::ArrayChildVisitor;
11use vortex_array::ArrayEq;
12use vortex_array::ArrayHash;
13use vortex_array::ArrayRef;
14use vortex_array::DeserializeMetadata;
15use vortex_array::ExecutionCtx;
16use vortex_array::IntoArray;
17use vortex_array::Precision;
18use vortex_array::ProstMetadata;
19use vortex_array::SerializeMetadata;
20use vortex_array::arrays::PrimitiveArray;
21use vortex_array::buffer::BufferHandle;
22use vortex_array::dtype::DType;
23use vortex_array::dtype::Nullability;
24use vortex_array::dtype::PType;
25use vortex_array::patches::Patches;
26use vortex_array::patches::PatchesMetadata;
27use vortex_array::serde::ArrayChildren;
28use vortex_array::stats::ArrayStats;
29use vortex_array::stats::StatsSetRef;
30use vortex_array::validity::Validity;
31use vortex_array::vtable;
32use vortex_array::vtable::ArrayId;
33use vortex_array::vtable::BaseArrayVTable;
34use vortex_array::vtable::VTable;
35use vortex_array::vtable::ValidityChild;
36use vortex_array::vtable::ValidityVTableFromChild;
37use vortex_array::vtable::VisitorVTable;
38use vortex_buffer::Buffer;
39use vortex_error::VortexResult;
40use vortex_error::vortex_bail;
41use vortex_error::vortex_ensure;
42use vortex_error::vortex_err;
43use vortex_mask::Mask;
44use vortex_session::VortexSession;
45
46use crate::alp_rd::kernel::PARENT_KERNELS;
47use crate::alp_rd::rules::RULES;
48use crate::alp_rd_decode;
49
// Invoke `vortex_array`'s `vtable!` macro for the ALPRD encoding.
// NOTE(review): the macro expansion is not visible in this file; presumably it generates the
// glue that ties `ALPRDArray` and `ALPRDVTable` together — confirm against `vortex_array::vtable`.
vtable!(ALPRD);
51
/// Protobuf-encoded metadata stored alongside a serialized ALP-RD array.
///
/// Produced by `VTable::metadata` and consumed by `VTable::build` in this file. The field tags
/// are part of the serialized format.
#[derive(Clone, prost::Message)]
pub struct ALPRDMetadata {
    /// The array's `right_bit_width` (a `u8` on the array), widened to `u32` for serialization.
    #[prost(uint32, tag = "1")]
    right_bit_width: u32,
    /// Number of leading entries of `dict` that `build` reads back as the left-parts dictionary.
    #[prost(uint32, tag = "2")]
    dict_len: u32,
    /// The left-parts dictionary codes (`u16` on the array), each widened to `u32`.
    #[prost(uint32, repeated, tag = "3")]
    dict: Vec<u32>,
    /// Primitive type of the `left_parts` child, stored as the `PType` enumeration value.
    #[prost(enumeration = "PType", tag = "4")]
    left_parts_ptype: i32,
    /// Metadata describing the left-parts patches; `None` when the array has no patches.
    #[prost(message, tag = "5")]
    patches: Option<PatchesMetadata>,
}
65
66impl VTable for ALPRDVTable {
67    type Array = ALPRDArray;
68
69    type Metadata = ProstMetadata<ALPRDMetadata>;
70
71    type ArrayVTable = Self;
72    type OperationsVTable = Self;
73    type ValidityVTable = ValidityVTableFromChild;
74    type VisitorVTable = Self;
75
76    fn id(_array: &Self::Array) -> ArrayId {
77        Self::ID
78    }
79
80    fn metadata(array: &ALPRDArray) -> VortexResult<Self::Metadata> {
81        let dict = array
82            .left_parts_dictionary()
83            .iter()
84            .map(|&i| i as u32)
85            .collect::<Vec<_>>();
86
87        Ok(ProstMetadata(ALPRDMetadata {
88            right_bit_width: array.right_bit_width() as u32,
89            dict_len: array.left_parts_dictionary().len() as u32,
90            dict,
91            left_parts_ptype: array.left_parts.dtype().as_ptype() as i32,
92            patches: array
93                .left_parts_patches()
94                .map(|p| p.to_metadata(array.len(), array.left_parts().dtype()))
95                .transpose()?,
96        }))
97    }
98
99    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
100        Ok(Some(metadata.serialize()))
101    }
102
103    fn deserialize(
104        bytes: &[u8],
105        _dtype: &DType,
106        _len: usize,
107        _buffers: &[BufferHandle],
108        _session: &VortexSession,
109    ) -> VortexResult<Self::Metadata> {
110        Ok(ProstMetadata(
111            <ProstMetadata<ALPRDMetadata> as DeserializeMetadata>::deserialize(bytes)?,
112        ))
113    }
114
115    fn build(
116        dtype: &DType,
117        len: usize,
118        metadata: &Self::Metadata,
119        _buffers: &[BufferHandle],
120        children: &dyn ArrayChildren,
121    ) -> VortexResult<ALPRDArray> {
122        if children.len() < 2 {
123            vortex_bail!(
124                "Expected at least 2 children for ALPRD encoding, found {}",
125                children.len()
126            );
127        }
128
129        let left_parts_dtype = DType::Primitive(metadata.0.left_parts_ptype(), dtype.nullability());
130        let left_parts = children.get(0, &left_parts_dtype, len)?;
131        let left_parts_dictionary: Buffer<u16> = metadata.0.dict.as_slice()
132            [0..metadata.0.dict_len as usize]
133            .iter()
134            .map(|&i| {
135                u16::try_from(i)
136                    .map_err(|_| vortex_err!("left_parts_dictionary code {i} does not fit in u16"))
137            })
138            .try_collect()?;
139
140        let right_parts_dtype = match &dtype {
141            DType::Primitive(PType::F32, _) => {
142                DType::Primitive(PType::U32, Nullability::NonNullable)
143            }
144            DType::Primitive(PType::F64, _) => {
145                DType::Primitive(PType::U64, Nullability::NonNullable)
146            }
147            _ => vortex_bail!("Expected f32 or f64 dtype, got {:?}", dtype),
148        };
149        let right_parts = children.get(1, &right_parts_dtype, len)?;
150
151        let left_parts_patches = metadata
152            .0
153            .patches
154            .map(|p| {
155                let indices = children.get(2, &p.indices_dtype()?, p.len()?)?;
156                let values = children.get(3, &left_parts_dtype, p.len()?)?;
157
158                Patches::new(
159                    len,
160                    p.offset()?,
161                    indices,
162                    values,
163                    // TODO(0ax1): handle chunk offsets
164                    None,
165                )
166            })
167            .transpose()?;
168
169        ALPRDArray::try_new(
170            dtype.clone(),
171            left_parts,
172            left_parts_dictionary,
173            right_parts,
174            u8::try_from(metadata.0.right_bit_width).map_err(|_| {
175                vortex_err!(
176                    "right_bit_width {} out of u8 range",
177                    metadata.0.right_bit_width
178                )
179            })?,
180            left_parts_patches,
181        )
182    }
183
184    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
185        // Children: left_parts, right_parts, patches (if present): indices, values
186        let patches_info = array
187            .left_parts_patches
188            .as_ref()
189            .map(|p| (p.array_len(), p.offset()));
190
191        let expected_children = if patches_info.is_some() { 4 } else { 2 };
192
193        vortex_ensure!(
194            children.len() == expected_children,
195            "ALPRDArray expects {} children, got {}",
196            expected_children,
197            children.len()
198        );
199
200        let mut children_iter = children.into_iter();
201        array.left_parts = children_iter
202            .next()
203            .ok_or_else(|| vortex_err!("Expected left_parts child"))?;
204        array.right_parts = children_iter
205            .next()
206            .ok_or_else(|| vortex_err!("Expected right_parts child"))?;
207
208        if let Some((array_len, offset)) = patches_info {
209            let indices = children_iter
210                .next()
211                .ok_or_else(|| vortex_err!("Expected patch indices child"))?;
212            let values = children_iter
213                .next()
214                .ok_or_else(|| vortex_err!("Expected patch values child"))?;
215
216            array.left_parts_patches = Some(Patches::new(
217                array_len, offset, indices, values,
218                None, // chunk_offsets not currently supported for ALPRD
219            )?);
220        }
221
222        Ok(())
223    }
224
225    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
226        let left_parts = array.left_parts().clone().execute::<PrimitiveArray>(ctx)?;
227        let right_parts = array.right_parts().clone().execute::<PrimitiveArray>(ctx)?;
228
229        // Decode the left_parts using our builtin dictionary.
230        let left_parts_dict = array.left_parts_dictionary();
231
232        let validity = array
233            .left_parts()
234            .validity()?
235            .to_array(array.len())
236            .execute::<Mask>(ctx)?;
237
238        let decoded_array = if array.is_f32() {
239            PrimitiveArray::new(
240                alp_rd_decode::<f32>(
241                    left_parts.into_buffer::<u16>(),
242                    left_parts_dict,
243                    array.right_bit_width,
244                    right_parts.into_buffer_mut::<u32>(),
245                    array.left_parts_patches(),
246                    ctx,
247                )?,
248                Validity::from_mask(validity, array.dtype().nullability()),
249            )
250        } else {
251            PrimitiveArray::new(
252                alp_rd_decode::<f64>(
253                    left_parts.into_buffer::<u16>(),
254                    left_parts_dict,
255                    array.right_bit_width,
256                    right_parts.into_buffer_mut::<u64>(),
257                    array.left_parts_patches(),
258                    ctx,
259                )?,
260                Validity::from_mask(validity, array.dtype().nullability()),
261            )
262        };
263
264        Ok(decoded_array.into_array())
265    }
266
267    fn reduce_parent(
268        array: &Self::Array,
269        parent: &ArrayRef,
270        child_idx: usize,
271    ) -> VortexResult<Option<ArrayRef>> {
272        RULES.evaluate(array, parent, child_idx)
273    }
274
275    fn execute_parent(
276        array: &Self::Array,
277        parent: &ArrayRef,
278        child_idx: usize,
279        ctx: &mut ExecutionCtx,
280    ) -> VortexResult<Option<ArrayRef>> {
281        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
282    }
283}
284
/// An array of floating point values split into dictionary-encoded left (most significant)
/// parts and raw right (least significant) parts.
#[derive(Clone, Debug)]
pub struct ALPRDArray {
    /// Logical dtype of the decoded values; `try_new` requires it to be a float type.
    dtype: DType,
    /// Leftmost (most significant) bits of each value, stored as dictionary codes.
    /// This child also provides the array's length and validity.
    left_parts: ArrayRef,
    /// Optional patches for the left parts.
    left_parts_patches: Option<Patches>,
    /// Dictionary mapping the codes in `left_parts` to bit patterns.
    left_parts_dictionary: Buffer<u16>,
    /// Rightmost (least significant) bits of each value; a non-nullable unsigned int array.
    right_parts: ArrayRef,
    /// NOTE(review): presumably the number of low-order bits of each value kept in
    /// `right_parts` — confirm against the encoder / `alp_rd_decode`.
    right_bit_width: u8,
    /// Statistics for this array; starts out as `Default::default()` on construction.
    stats_set: ArrayStats,
}
295
/// Zero-sized marker type carrying the vtable trait implementations for [`ALPRDArray`].
#[derive(Debug)]
pub struct ALPRDVTable;
298
impl ALPRDVTable {
    /// Identifier of the ALP-RD encoding; returned by `VTable::id` for every ALP-RD array.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.alprd");
}
302
303impl ALPRDArray {
304    /// Build a new `ALPRDArray` from components.
305    pub fn try_new(
306        dtype: DType,
307        left_parts: ArrayRef,
308        left_parts_dictionary: Buffer<u16>,
309        right_parts: ArrayRef,
310        right_bit_width: u8,
311        left_parts_patches: Option<Patches>,
312    ) -> VortexResult<Self> {
313        if !dtype.is_float() {
314            vortex_bail!("ALPRDArray given invalid DType ({dtype})");
315        }
316
317        let len = left_parts.len();
318        if right_parts.len() != len {
319            vortex_bail!(
320                "left_parts (len {}) and right_parts (len {}) must be of same length",
321                len,
322                right_parts.len()
323            );
324        }
325
326        if !left_parts.dtype().is_unsigned_int() {
327            vortex_bail!("left_parts dtype must be uint");
328        }
329        // we delegate array validity to the left_parts child
330        if dtype.is_nullable() != left_parts.dtype().is_nullable() {
331            vortex_bail!(
332                "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
333                dtype,
334                left_parts.dtype()
335            );
336        }
337
338        // we enforce right_parts to be non-nullable uint
339        if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
340            vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
341        }
342
343        let left_parts_patches = left_parts_patches
344            .map(|patches| {
345                if !patches.values().all_valid()? {
346                    vortex_bail!("patches must be all valid: {}", patches.values());
347                }
348                // TODO(ngates): assert the DType, don't cast it.
349                // TODO(joe): assert the DType, don't cast it in the next PR.
350                let mut patches = patches.cast_values(left_parts.dtype())?;
351                // Force execution of the lazy cast so patch values are materialized
352                // before serialization.
353                *patches.values_mut() = patches.values().to_canonical()?.into_array();
354                Ok(patches)
355            })
356            .transpose()?;
357
358        Ok(Self {
359            dtype,
360            left_parts,
361            left_parts_dictionary,
362            right_parts,
363            right_bit_width,
364            left_parts_patches,
365            stats_set: Default::default(),
366        })
367    }
368
369    /// Build a new `ALPRDArray` from components. This does not perform any validation, and instead
370    /// it constructs it from parts.
371    pub(crate) unsafe fn new_unchecked(
372        dtype: DType,
373        left_parts: ArrayRef,
374        left_parts_dictionary: Buffer<u16>,
375        right_parts: ArrayRef,
376        right_bit_width: u8,
377        left_parts_patches: Option<Patches>,
378    ) -> Self {
379        Self {
380            dtype,
381            left_parts,
382            left_parts_patches,
383            left_parts_dictionary,
384            right_parts,
385            right_bit_width,
386            stats_set: Default::default(),
387        }
388    }
389
390    /// Returns true if logical type of the array values is f32.
391    ///
392    /// Returns false if the logical type of the array values is f64.
393    #[inline]
394    pub fn is_f32(&self) -> bool {
395        matches!(&self.dtype, DType::Primitive(PType::F32, _))
396    }
397
398    /// The leftmost (most significant) bits of the floating point values stored in the array.
399    ///
400    /// These are bit-packed and dictionary encoded, and cannot directly be interpreted without
401    /// the metadata of this array.
402    pub fn left_parts(&self) -> &ArrayRef {
403        &self.left_parts
404    }
405
406    /// The rightmost (least significant) bits of the floating point values stored in the array.
407    pub fn right_parts(&self) -> &ArrayRef {
408        &self.right_parts
409    }
410
411    #[inline]
412    pub fn right_bit_width(&self) -> u8 {
413        self.right_bit_width
414    }
415
416    /// Patches of left-most bits.
417    pub fn left_parts_patches(&self) -> Option<&Patches> {
418        self.left_parts_patches.as_ref()
419    }
420
421    /// The dictionary that maps the codes in `left_parts` into bit patterns.
422    #[inline]
423    pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
424        &self.left_parts_dictionary
425    }
426
427    pub fn replace_left_parts_patches(&mut self, patches: Option<Patches>) {
428        self.left_parts_patches = patches;
429    }
430}
431
432impl ValidityChild<ALPRDVTable> for ALPRDVTable {
433    fn validity_child(array: &ALPRDArray) -> &ArrayRef {
434        array.left_parts()
435    }
436}
437
438impl BaseArrayVTable<ALPRDVTable> for ALPRDVTable {
439    fn len(array: &ALPRDArray) -> usize {
440        array.left_parts.len()
441    }
442
443    fn dtype(array: &ALPRDArray) -> &DType {
444        &array.dtype
445    }
446
447    fn stats(array: &ALPRDArray) -> StatsSetRef<'_> {
448        array.stats_set.to_ref(array.as_ref())
449    }
450
451    fn array_hash<H: std::hash::Hasher>(array: &ALPRDArray, state: &mut H, precision: Precision) {
452        array.dtype.hash(state);
453        array.left_parts.array_hash(state, precision);
454        array.left_parts_dictionary.array_hash(state, precision);
455        array.right_parts.array_hash(state, precision);
456        array.right_bit_width.hash(state);
457        array.left_parts_patches.array_hash(state, precision);
458    }
459
460    fn array_eq(array: &ALPRDArray, other: &ALPRDArray, precision: Precision) -> bool {
461        array.dtype == other.dtype
462            && array.left_parts.array_eq(&other.left_parts, precision)
463            && array
464                .left_parts_dictionary
465                .array_eq(&other.left_parts_dictionary, precision)
466            && array.right_parts.array_eq(&other.right_parts, precision)
467            && array.right_bit_width == other.right_bit_width
468            && array
469                .left_parts_patches
470                .array_eq(&other.left_parts_patches, precision)
471    }
472}
473
474impl VisitorVTable<ALPRDVTable> for ALPRDVTable {
475    fn visit_buffers(_array: &ALPRDArray, _visitor: &mut dyn ArrayBufferVisitor) {}
476
477    fn nbuffers(_array: &ALPRDArray) -> usize {
478        0
479    }
480
481    fn visit_children(array: &ALPRDArray, visitor: &mut dyn ArrayChildVisitor) {
482        visitor.visit_child("left_parts", array.left_parts());
483        visitor.visit_child("right_parts", array.right_parts());
484        if let Some(patches) = array.left_parts_patches() {
485            visitor.visit_patches(patches);
486        }
487    }
488
489    fn nchildren(array: &ALPRDArray) -> usize {
490        // left_parts + right_parts + optional patches (indices + values)
491        2 + array
492            .left_parts_patches()
493            .map_or(0, |p| 2 + p.chunk_offsets().is_some() as usize)
494    }
495}
496
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ProstMetadata;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::PType;
    use vortex_array::patches::PatchesMetadata;
    use vortex_array::test_harness::check_metadata;

    use super::ALPRDMetadata;
    use crate::ALPRDFloat;
    use crate::alp_rd;

    /// Round-trips a nullable float array through the RD encoder and checks that decoding
    /// reproduces the input, nulls included.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");

        // Wrap every value in `Some`, then null out a few positions.
        let mut expected: Vec<Option<T>> = reals.into_iter().map(Some).collect();
        for null_idx in [1usize, 5, 900] {
            expected[null_idx] = None;
        }

        // The array that is fed to the encoder.
        let real_array = PrimitiveArray::from_option_iter(expected.iter().cloned());

        // Pick a seed that we know will trigger lots of patches.
        let encoder: alp_rd::RDEncoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);
        let rd_array = encoder.encode(&real_array);

        let decoded = rd_array.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_option_iter(expected));
    }

    /// Checks the serialized form of the metadata against the stored `alprd.metadata` fixture.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_alprd_metadata() {
        let patches = PatchesMetadata::new(usize::MAX, usize::MAX, PType::U64, None, None, None);
        let metadata = ALPRDMetadata {
            right_bit_width: u32::MAX,
            dict_len: 8,
            dict: Vec::new(),
            left_parts_ptype: PType::U64 as i32,
            patches: Some(patches),
        };

        check_metadata("alprd.metadata", ProstMetadata(metadata));
    }
}