Skip to main content

vortex_alp/alp_rd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use itertools::Itertools;
8use vortex_array::Array;
9use vortex_array::ArrayBufferVisitor;
10use vortex_array::ArrayChildVisitor;
11use vortex_array::ArrayEq;
12use vortex_array::ArrayHash;
13use vortex_array::ArrayRef;
14use vortex_array::DeserializeMetadata;
15use vortex_array::ExecutionCtx;
16use vortex_array::IntoArray;
17use vortex_array::Precision;
18use vortex_array::ProstMetadata;
19use vortex_array::SerializeMetadata;
20use vortex_array::arrays::PrimitiveArray;
21use vortex_array::buffer::BufferHandle;
22use vortex_array::patches::Patches;
23use vortex_array::patches::PatchesMetadata;
24use vortex_array::serde::ArrayChildren;
25use vortex_array::stats::ArrayStats;
26use vortex_array::stats::StatsSetRef;
27use vortex_array::validity::Validity;
28use vortex_array::vtable;
29use vortex_array::vtable::ArrayId;
30use vortex_array::vtable::BaseArrayVTable;
31use vortex_array::vtable::VTable;
32use vortex_array::vtable::ValidityChild;
33use vortex_array::vtable::ValidityVTableFromChild;
34use vortex_array::vtable::VisitorVTable;
35use vortex_buffer::Buffer;
36use vortex_dtype::DType;
37use vortex_dtype::Nullability;
38use vortex_dtype::PType;
39use vortex_error::VortexExpect;
40use vortex_error::VortexResult;
41use vortex_error::vortex_bail;
42use vortex_error::vortex_ensure;
43use vortex_error::vortex_err;
44use vortex_mask::Mask;
45use vortex_session::VortexSession;
46
47use crate::alp_rd::kernel::PARENT_KERNELS;
48use crate::alp_rd::rules::RULES;
49use crate::alp_rd_decode;
50
vtable!(ALPRD);

/// Serialized (protobuf, via prost) metadata for an ALP-RD array.
///
/// Holds everything needed to rebuild an `ALPRDArray` that is not stored in its child arrays.
#[derive(Clone, prost::Message)]
pub struct ALPRDMetadata {
    /// Width of the right (least-significant) part, widened from the array's `u8` value.
    #[prost(uint32, tag = "1")]
    right_bit_width: u32,
    /// Number of valid entries in `dict`; `build` only reads the first `dict_len` codes.
    #[prost(uint32, tag = "2")]
    dict_len: u32,
    /// Left-parts dictionary, widened from `u16`; `build` narrows each entry back to `u16`
    /// and fails if an entry does not fit.
    #[prost(uint32, repeated, tag = "3")]
    dict: Vec<u32>,
    /// `PType` of the `left_parts` child, stored as its prost enum discriminant.
    #[prost(enumeration = "PType", tag = "4")]
    left_parts_ptype: i32,
    /// Layout of the left-parts patches; `None` when the array has no patches.
    #[prost(message, tag = "5")]
    patches: Option<PatchesMetadata>,
}
66
67impl VTable for ALPRDVTable {
68    type Array = ALPRDArray;
69
70    type Metadata = ProstMetadata<ALPRDMetadata>;
71
72    type ArrayVTable = Self;
73    type OperationsVTable = Self;
74    type ValidityVTable = ValidityVTableFromChild;
75    type VisitorVTable = Self;
76
77    fn id(_array: &Self::Array) -> ArrayId {
78        Self::ID
79    }
80
81    fn metadata(array: &ALPRDArray) -> VortexResult<Self::Metadata> {
82        let dict = array
83            .left_parts_dictionary()
84            .iter()
85            .map(|&i| i as u32)
86            .collect::<Vec<_>>();
87
88        Ok(ProstMetadata(ALPRDMetadata {
89            right_bit_width: array.right_bit_width() as u32,
90            dict_len: array.left_parts_dictionary().len() as u32,
91            dict,
92            left_parts_ptype: PType::try_from(array.left_parts().dtype())
93                .vortex_expect("Must be a valid PType") as i32,
94            patches: array
95                .left_parts_patches()
96                .map(|p| p.to_metadata(array.len(), array.left_parts().dtype()))
97                .transpose()?,
98        }))
99    }
100
101    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
102        Ok(Some(metadata.serialize()))
103    }
104
105    fn deserialize(
106        bytes: &[u8],
107        _dtype: &DType,
108        _len: usize,
109        _buffers: &[BufferHandle],
110        _session: &VortexSession,
111    ) -> VortexResult<Self::Metadata> {
112        Ok(ProstMetadata(
113            <ProstMetadata<ALPRDMetadata> as DeserializeMetadata>::deserialize(bytes)?,
114        ))
115    }
116
117    fn build(
118        dtype: &DType,
119        len: usize,
120        metadata: &Self::Metadata,
121        _buffers: &[BufferHandle],
122        children: &dyn ArrayChildren,
123    ) -> VortexResult<ALPRDArray> {
124        if children.len() < 2 {
125            vortex_bail!(
126                "Expected at least 2 children for ALPRD encoding, found {}",
127                children.len()
128            );
129        }
130
131        let left_parts_dtype = DType::Primitive(metadata.0.left_parts_ptype(), dtype.nullability());
132        let left_parts = children.get(0, &left_parts_dtype, len)?;
133        let left_parts_dictionary: Buffer<u16> = metadata.0.dict.as_slice()
134            [0..metadata.0.dict_len as usize]
135            .iter()
136            .map(|&i| {
137                u16::try_from(i)
138                    .map_err(|_| vortex_err!("left_parts_dictionary code {i} does not fit in u16"))
139            })
140            .try_collect()?;
141
142        let right_parts_dtype = match &dtype {
143            DType::Primitive(PType::F32, _) => {
144                DType::Primitive(PType::U32, Nullability::NonNullable)
145            }
146            DType::Primitive(PType::F64, _) => {
147                DType::Primitive(PType::U64, Nullability::NonNullable)
148            }
149            _ => vortex_bail!("Expected f32 or f64 dtype, got {:?}", dtype),
150        };
151        let right_parts = children.get(1, &right_parts_dtype, len)?;
152
153        let left_parts_patches = metadata
154            .0
155            .patches
156            .map(|p| {
157                let indices = children.get(2, &p.indices_dtype()?, p.len()?)?;
158                let values = children.get(3, &left_parts_dtype, p.len()?)?;
159
160                Patches::new(
161                    len,
162                    p.offset()?,
163                    indices,
164                    values,
165                    // TODO(0ax1): handle chunk offsets
166                    None,
167                )
168            })
169            .transpose()?;
170
171        ALPRDArray::try_new(
172            dtype.clone(),
173            left_parts,
174            left_parts_dictionary,
175            right_parts,
176            u8::try_from(metadata.0.right_bit_width).map_err(|_| {
177                vortex_err!(
178                    "right_bit_width {} out of u8 range",
179                    metadata.0.right_bit_width
180                )
181            })?,
182            left_parts_patches,
183        )
184    }
185
186    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
187        // Children: left_parts, right_parts, patches (if present): indices, values
188        let patches_info = array
189            .left_parts_patches
190            .as_ref()
191            .map(|p| (p.array_len(), p.offset()));
192
193        let expected_children = if patches_info.is_some() { 4 } else { 2 };
194
195        vortex_ensure!(
196            children.len() == expected_children,
197            "ALPRDArray expects {} children, got {}",
198            expected_children,
199            children.len()
200        );
201
202        let mut children_iter = children.into_iter();
203        array.left_parts = children_iter
204            .next()
205            .ok_or_else(|| vortex_err!("Expected left_parts child"))?;
206        array.right_parts = children_iter
207            .next()
208            .ok_or_else(|| vortex_err!("Expected right_parts child"))?;
209
210        if let Some((array_len, offset)) = patches_info {
211            let indices = children_iter
212                .next()
213                .ok_or_else(|| vortex_err!("Expected patch indices child"))?;
214            let values = children_iter
215                .next()
216                .ok_or_else(|| vortex_err!("Expected patch values child"))?;
217
218            array.left_parts_patches = Some(Patches::new(
219                array_len, offset, indices, values,
220                None, // chunk_offsets not currently supported for ALPRD
221            )?);
222        }
223
224        Ok(())
225    }
226
227    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
228        let left_parts = array.left_parts().clone().execute::<PrimitiveArray>(ctx)?;
229        let right_parts = array.right_parts().clone().execute::<PrimitiveArray>(ctx)?;
230
231        // Decode the left_parts using our builtin dictionary.
232        let left_parts_dict = array.left_parts_dictionary();
233
234        let validity = array
235            .left_parts()
236            .validity()?
237            .to_array(array.len())
238            .execute::<Mask>(ctx)?;
239
240        let decoded_array = if array.is_f32() {
241            PrimitiveArray::new(
242                alp_rd_decode::<f32>(
243                    left_parts.into_buffer::<u16>(),
244                    left_parts_dict,
245                    array.right_bit_width,
246                    right_parts.into_buffer_mut::<u32>(),
247                    array.left_parts_patches(),
248                    ctx,
249                )?,
250                Validity::from_mask(validity, array.dtype().nullability()),
251            )
252        } else {
253            PrimitiveArray::new(
254                alp_rd_decode::<f64>(
255                    left_parts.into_buffer::<u16>(),
256                    left_parts_dict,
257                    array.right_bit_width,
258                    right_parts.into_buffer_mut::<u64>(),
259                    array.left_parts_patches(),
260                    ctx,
261                )?,
262                Validity::from_mask(validity, array.dtype().nullability()),
263            )
264        };
265
266        Ok(decoded_array.into_array())
267    }
268
269    fn reduce_parent(
270        array: &Self::Array,
271        parent: &ArrayRef,
272        child_idx: usize,
273    ) -> VortexResult<Option<ArrayRef>> {
274        RULES.evaluate(array, parent, child_idx)
275    }
276
277    fn execute_parent(
278        array: &Self::Array,
279        parent: &ArrayRef,
280        child_idx: usize,
281        ctx: &mut ExecutionCtx,
282    ) -> VortexResult<Option<ArrayRef>> {
283        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
284    }
285}
286
/// A floating-point array encoded with ALP-RD.
///
/// Each value is split into left (most-significant) and right (least-significant) bits. The
/// left bits are dictionary-encoded, with optional patches; the right bits are stored as-is.
#[derive(Clone, Debug)]
pub struct ALPRDArray {
    /// Logical dtype of the decoded values.
    dtype: DType,
    /// Dictionary codes for the left bits; this child also provides the array's validity.
    left_parts: ArrayRef,
    /// Optional patches applied to the left parts during decoding (presumably for values whose
    /// left bits are not in the dictionary — confirm against the encoder).
    left_parts_patches: Option<Patches>,
    /// Maps the codes in `left_parts` to `u16` left-bit patterns.
    left_parts_dictionary: Buffer<u16>,
    /// Right bits of each value; `try_new` requires a non-nullable unsigned integer array.
    right_parts: ArrayRef,
    /// Width of the right part, passed to `alp_rd_decode`.
    right_bit_width: u8,
    /// Cached array statistics.
    stats_set: ArrayStats,
}

/// VTable marker type for [`ALPRDArray`].
#[derive(Debug)]
pub struct ALPRDVTable;

impl ALPRDVTable {
    /// Encoding identifier for ALP-RD arrays.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.alprd");
}
304
305impl ALPRDArray {
306    /// Build a new `ALPRDArray` from components.
307    pub fn try_new(
308        dtype: DType,
309        left_parts: ArrayRef,
310        left_parts_dictionary: Buffer<u16>,
311        right_parts: ArrayRef,
312        right_bit_width: u8,
313        left_parts_patches: Option<Patches>,
314    ) -> VortexResult<Self> {
315        if !dtype.is_float() {
316            vortex_bail!("ALPRDArray given invalid DType ({dtype})");
317        }
318
319        let len = left_parts.len();
320        if right_parts.len() != len {
321            vortex_bail!(
322                "left_parts (len {}) and right_parts (len {}) must be of same length",
323                len,
324                right_parts.len()
325            );
326        }
327
328        if !left_parts.dtype().is_unsigned_int() {
329            vortex_bail!("left_parts dtype must be uint");
330        }
331        // we delegate array validity to the left_parts child
332        if dtype.is_nullable() != left_parts.dtype().is_nullable() {
333            vortex_bail!(
334                "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
335                dtype,
336                left_parts.dtype()
337            );
338        }
339
340        // we enforce right_parts to be non-nullable uint
341        if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
342            vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
343        }
344
345        let left_parts_patches = left_parts_patches
346            .map(|patches| {
347                if !patches.values().all_valid()? {
348                    vortex_bail!("patches must be all valid: {}", patches.values());
349                }
350                // TODO(ngates): assert the DType, don't cast it.
351                // TODO(joe): assert the DType, don't cast it in the next PR.
352                let mut patches = patches.cast_values(left_parts.dtype())?;
353                // Force execution of the lazy cast so patch values are materialized
354                // before serialization.
355                *patches.values_mut() = patches.values().to_canonical()?.into_array();
356                Ok(patches)
357            })
358            .transpose()?;
359
360        Ok(Self {
361            dtype,
362            left_parts,
363            left_parts_dictionary,
364            right_parts,
365            right_bit_width,
366            left_parts_patches,
367            stats_set: Default::default(),
368        })
369    }
370
371    /// Build a new `ALPRDArray` from components. This does not perform any validation, and instead
372    /// it constructs it from parts.
373    pub(crate) unsafe fn new_unchecked(
374        dtype: DType,
375        left_parts: ArrayRef,
376        left_parts_dictionary: Buffer<u16>,
377        right_parts: ArrayRef,
378        right_bit_width: u8,
379        left_parts_patches: Option<Patches>,
380    ) -> Self {
381        Self {
382            dtype,
383            left_parts,
384            left_parts_patches,
385            left_parts_dictionary,
386            right_parts,
387            right_bit_width,
388            stats_set: Default::default(),
389        }
390    }
391
392    /// Returns true if logical type of the array values is f32.
393    ///
394    /// Returns false if the logical type of the array values is f64.
395    #[inline]
396    pub fn is_f32(&self) -> bool {
397        matches!(&self.dtype, DType::Primitive(PType::F32, _))
398    }
399
400    /// The leftmost (most significant) bits of the floating point values stored in the array.
401    ///
402    /// These are bit-packed and dictionary encoded, and cannot directly be interpreted without
403    /// the metadata of this array.
404    pub fn left_parts(&self) -> &ArrayRef {
405        &self.left_parts
406    }
407
408    /// The rightmost (least significant) bits of the floating point values stored in the array.
409    pub fn right_parts(&self) -> &ArrayRef {
410        &self.right_parts
411    }
412
413    #[inline]
414    pub fn right_bit_width(&self) -> u8 {
415        self.right_bit_width
416    }
417
418    /// Patches of left-most bits.
419    pub fn left_parts_patches(&self) -> Option<&Patches> {
420        self.left_parts_patches.as_ref()
421    }
422
423    /// The dictionary that maps the codes in `left_parts` into bit patterns.
424    #[inline]
425    pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
426        &self.left_parts_dictionary
427    }
428
429    pub fn replace_left_parts_patches(&mut self, patches: Option<Patches>) {
430        self.left_parts_patches = patches;
431    }
432}
433
434impl ValidityChild<ALPRDVTable> for ALPRDVTable {
435    fn validity_child(array: &ALPRDArray) -> &ArrayRef {
436        array.left_parts()
437    }
438}
439
440impl BaseArrayVTable<ALPRDVTable> for ALPRDVTable {
441    fn len(array: &ALPRDArray) -> usize {
442        array.left_parts.len()
443    }
444
445    fn dtype(array: &ALPRDArray) -> &DType {
446        &array.dtype
447    }
448
449    fn stats(array: &ALPRDArray) -> StatsSetRef<'_> {
450        array.stats_set.to_ref(array.as_ref())
451    }
452
453    fn array_hash<H: std::hash::Hasher>(array: &ALPRDArray, state: &mut H, precision: Precision) {
454        array.dtype.hash(state);
455        array.left_parts.array_hash(state, precision);
456        array.left_parts_dictionary.array_hash(state, precision);
457        array.right_parts.array_hash(state, precision);
458        array.right_bit_width.hash(state);
459        array.left_parts_patches.array_hash(state, precision);
460    }
461
462    fn array_eq(array: &ALPRDArray, other: &ALPRDArray, precision: Precision) -> bool {
463        array.dtype == other.dtype
464            && array.left_parts.array_eq(&other.left_parts, precision)
465            && array
466                .left_parts_dictionary
467                .array_eq(&other.left_parts_dictionary, precision)
468            && array.right_parts.array_eq(&other.right_parts, precision)
469            && array.right_bit_width == other.right_bit_width
470            && array
471                .left_parts_patches
472                .array_eq(&other.left_parts_patches, precision)
473    }
474}
475
476impl VisitorVTable<ALPRDVTable> for ALPRDVTable {
477    fn visit_buffers(_array: &ALPRDArray, _visitor: &mut dyn ArrayBufferVisitor) {}
478
479    fn nbuffers(_array: &ALPRDArray) -> usize {
480        0
481    }
482
483    fn visit_children(array: &ALPRDArray, visitor: &mut dyn ArrayChildVisitor) {
484        visitor.visit_child("left_parts", array.left_parts());
485        visitor.visit_child("right_parts", array.right_parts());
486        if let Some(patches) = array.left_parts_patches() {
487            visitor.visit_patches(patches);
488        }
489    }
490
491    fn nchildren(array: &ALPRDArray) -> usize {
492        // left_parts + right_parts + optional patches (indices + values)
493        2 + array
494            .left_parts_patches()
495            .map_or(0, |p| 2 + p.chunk_offsets().is_some() as usize)
496    }
497}
498
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ProstMetadata;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::patches::PatchesMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use super::ALPRDMetadata;
    use crate::ALPRDFloat;
    use crate::alp_rd;

    /// Positions that are nulled out before encoding.
    const NULL_POSITIONS: [usize; 3] = [1, 5, 900];

    /// Round-trips a nullable array through ALP-RD, using a seed that forces many patches.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");

        let expected: Vec<Option<T>> = reals
            .into_iter()
            .enumerate()
            .map(|(idx, value)| (!NULL_POSITIONS.contains(&idx)).then_some(value))
            .collect();

        let input = PrimitiveArray::from_option_iter(expected.iter().cloned());

        // Pick a seed that we know will trigger lots of patches.
        let encoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);
        let decoded = encoder.encode(&input).to_primitive();

        assert_arrays_eq!(decoded, PrimitiveArray::from_option_iter(expected));
    }

    /// Snapshot of the serialized metadata layout with extreme field values.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_alprd_metadata() {
        let metadata = ALPRDMetadata {
            right_bit_width: u32::MAX,
            dict_len: 8,
            dict: Vec::new(),
            left_parts_ptype: PType::U64 as i32,
            patches: Some(PatchesMetadata::new(
                usize::MAX,
                usize::MAX,
                PType::U64,
                None,
                None,
                None,
            )),
        };
        check_metadata("alprd.metadata", ProstMetadata(metadata));
    }
}