Skip to main content

vortex_alp/alp_rd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use itertools::Itertools;
8use vortex_array::Array;
9use vortex_array::ArrayEq;
10use vortex_array::ArrayHash;
11use vortex_array::ArrayRef;
12use vortex_array::DeserializeMetadata;
13use vortex_array::ExecutionCtx;
14use vortex_array::IntoArray;
15use vortex_array::Precision;
16use vortex_array::ProstMetadata;
17use vortex_array::SerializeMetadata;
18use vortex_array::arrays::PrimitiveArray;
19use vortex_array::buffer::BufferHandle;
20use vortex_array::dtype::DType;
21use vortex_array::dtype::Nullability;
22use vortex_array::dtype::PType;
23use vortex_array::patches::Patches;
24use vortex_array::patches::PatchesMetadata;
25use vortex_array::serde::ArrayChildren;
26use vortex_array::stats::ArrayStats;
27use vortex_array::stats::StatsSetRef;
28use vortex_array::validity::Validity;
29use vortex_array::vtable;
30use vortex_array::vtable::ArrayId;
31use vortex_array::vtable::VTable;
32use vortex_array::vtable::ValidityChild;
33use vortex_array::vtable::ValidityVTableFromChild;
34use vortex_array::vtable::patches_child;
35use vortex_array::vtable::patches_child_name;
36use vortex_array::vtable::patches_nchildren;
37use vortex_buffer::Buffer;
38use vortex_error::VortexResult;
39use vortex_error::vortex_bail;
40use vortex_error::vortex_ensure;
41use vortex_error::vortex_err;
42use vortex_error::vortex_panic;
43use vortex_mask::Mask;
44use vortex_session::VortexSession;
45
46use crate::alp_rd::kernel::PARENT_KERNELS;
47use crate::alp_rd::rules::RULES;
48use crate::alp_rd_decode;
49
50vtable!(ALPRD);
51
52#[derive(Clone, prost::Message)]
53pub struct ALPRDMetadata {
54    #[prost(uint32, tag = "1")]
55    right_bit_width: u32,
56    #[prost(uint32, tag = "2")]
57    dict_len: u32,
58    #[prost(uint32, repeated, tag = "3")]
59    dict: Vec<u32>,
60    #[prost(enumeration = "PType", tag = "4")]
61    left_parts_ptype: i32,
62    #[prost(message, tag = "5")]
63    patches: Option<PatchesMetadata>,
64}
65
66impl VTable for ALPRDVTable {
67    type Array = ALPRDArray;
68
69    type Metadata = ProstMetadata<ALPRDMetadata>;
70    type OperationsVTable = Self;
71    type ValidityVTable = ValidityVTableFromChild;
72
73    fn id(_array: &Self::Array) -> ArrayId {
74        Self::ID
75    }
76
77    fn len(array: &ALPRDArray) -> usize {
78        array.left_parts.len()
79    }
80
81    fn dtype(array: &ALPRDArray) -> &DType {
82        &array.dtype
83    }
84
85    fn stats(array: &ALPRDArray) -> StatsSetRef<'_> {
86        array.stats_set.to_ref(array.as_ref())
87    }
88
89    fn array_hash<H: std::hash::Hasher>(array: &ALPRDArray, state: &mut H, precision: Precision) {
90        array.dtype.hash(state);
91        array.left_parts.array_hash(state, precision);
92        array.left_parts_dictionary.array_hash(state, precision);
93        array.right_parts.array_hash(state, precision);
94        array.right_bit_width.hash(state);
95        array.left_parts_patches.array_hash(state, precision);
96    }
97
98    fn array_eq(array: &ALPRDArray, other: &ALPRDArray, precision: Precision) -> bool {
99        array.dtype == other.dtype
100            && array.left_parts.array_eq(&other.left_parts, precision)
101            && array
102                .left_parts_dictionary
103                .array_eq(&other.left_parts_dictionary, precision)
104            && array.right_parts.array_eq(&other.right_parts, precision)
105            && array.right_bit_width == other.right_bit_width
106            && array
107                .left_parts_patches
108                .array_eq(&other.left_parts_patches, precision)
109    }
110
111    fn nbuffers(_array: &ALPRDArray) -> usize {
112        0
113    }
114
115    fn buffer(_array: &ALPRDArray, idx: usize) -> BufferHandle {
116        vortex_panic!("ALPRDArray buffer index {idx} out of bounds")
117    }
118
119    fn buffer_name(_array: &ALPRDArray, _idx: usize) -> Option<String> {
120        None
121    }
122
123    fn nchildren(array: &ALPRDArray) -> usize {
124        2 + array.left_parts_patches().map_or(0, patches_nchildren)
125    }
126
127    fn child(array: &ALPRDArray, idx: usize) -> ArrayRef {
128        match idx {
129            0 => array.left_parts().clone(),
130            1 => array.right_parts().clone(),
131            _ => {
132                let patches = array
133                    .left_parts_patches()
134                    .unwrap_or_else(|| vortex_panic!("ALPRDArray child index {idx} out of bounds"));
135                patches_child(patches, idx - 2)
136            }
137        }
138    }
139
140    fn child_name(array: &ALPRDArray, idx: usize) -> String {
141        match idx {
142            0 => "left_parts".to_string(),
143            1 => "right_parts".to_string(),
144            _ => {
145                if array.left_parts_patches().is_none() {
146                    vortex_panic!("ALPRDArray child_name index {idx} out of bounds");
147                }
148                patches_child_name(idx - 2).to_string()
149            }
150        }
151    }
152
153    fn metadata(array: &ALPRDArray) -> VortexResult<Self::Metadata> {
154        let dict = array
155            .left_parts_dictionary()
156            .iter()
157            .map(|&i| i as u32)
158            .collect::<Vec<_>>();
159
160        Ok(ProstMetadata(ALPRDMetadata {
161            right_bit_width: array.right_bit_width() as u32,
162            dict_len: array.left_parts_dictionary().len() as u32,
163            dict,
164            left_parts_ptype: array.left_parts.dtype().as_ptype() as i32,
165            patches: array
166                .left_parts_patches()
167                .map(|p| p.to_metadata(array.len(), array.left_parts().dtype()))
168                .transpose()?,
169        }))
170    }
171
172    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
173        Ok(Some(metadata.serialize()))
174    }
175
176    fn deserialize(
177        bytes: &[u8],
178        _dtype: &DType,
179        _len: usize,
180        _buffers: &[BufferHandle],
181        _session: &VortexSession,
182    ) -> VortexResult<Self::Metadata> {
183        Ok(ProstMetadata(
184            <ProstMetadata<ALPRDMetadata> as DeserializeMetadata>::deserialize(bytes)?,
185        ))
186    }
187
188    fn build(
189        dtype: &DType,
190        len: usize,
191        metadata: &Self::Metadata,
192        _buffers: &[BufferHandle],
193        children: &dyn ArrayChildren,
194    ) -> VortexResult<ALPRDArray> {
195        if children.len() < 2 {
196            vortex_bail!(
197                "Expected at least 2 children for ALPRD encoding, found {}",
198                children.len()
199            );
200        }
201
202        let left_parts_dtype = DType::Primitive(metadata.0.left_parts_ptype(), dtype.nullability());
203        let left_parts = children.get(0, &left_parts_dtype, len)?;
204        let left_parts_dictionary: Buffer<u16> = metadata.0.dict.as_slice()
205            [0..metadata.0.dict_len as usize]
206            .iter()
207            .map(|&i| {
208                u16::try_from(i)
209                    .map_err(|_| vortex_err!("left_parts_dictionary code {i} does not fit in u16"))
210            })
211            .try_collect()?;
212
213        let right_parts_dtype = match &dtype {
214            DType::Primitive(PType::F32, _) => {
215                DType::Primitive(PType::U32, Nullability::NonNullable)
216            }
217            DType::Primitive(PType::F64, _) => {
218                DType::Primitive(PType::U64, Nullability::NonNullable)
219            }
220            _ => vortex_bail!("Expected f32 or f64 dtype, got {:?}", dtype),
221        };
222        let right_parts = children.get(1, &right_parts_dtype, len)?;
223
224        let left_parts_patches = metadata
225            .0
226            .patches
227            .map(|p| {
228                let indices = children.get(2, &p.indices_dtype()?, p.len()?)?;
229                let values = children.get(3, &left_parts_dtype, p.len()?)?;
230
231                Patches::new(
232                    len,
233                    p.offset()?,
234                    indices,
235                    values,
236                    // TODO(0ax1): handle chunk offsets
237                    None,
238                )
239            })
240            .transpose()?;
241
242        ALPRDArray::try_new(
243            dtype.clone(),
244            left_parts,
245            left_parts_dictionary,
246            right_parts,
247            u8::try_from(metadata.0.right_bit_width).map_err(|_| {
248                vortex_err!(
249                    "right_bit_width {} out of u8 range",
250                    metadata.0.right_bit_width
251                )
252            })?,
253            left_parts_patches,
254        )
255    }
256
257    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
258        // Children: left_parts, right_parts, patches (if present): indices, values
259        let patches_info = array
260            .left_parts_patches
261            .as_ref()
262            .map(|p| (p.array_len(), p.offset()));
263
264        let expected_children = if patches_info.is_some() { 4 } else { 2 };
265
266        vortex_ensure!(
267            children.len() == expected_children,
268            "ALPRDArray expects {} children, got {}",
269            expected_children,
270            children.len()
271        );
272
273        let mut children_iter = children.into_iter();
274        array.left_parts = children_iter
275            .next()
276            .ok_or_else(|| vortex_err!("Expected left_parts child"))?;
277        array.right_parts = children_iter
278            .next()
279            .ok_or_else(|| vortex_err!("Expected right_parts child"))?;
280
281        if let Some((array_len, offset)) = patches_info {
282            let indices = children_iter
283                .next()
284                .ok_or_else(|| vortex_err!("Expected patch indices child"))?;
285            let values = children_iter
286                .next()
287                .ok_or_else(|| vortex_err!("Expected patch values child"))?;
288
289            array.left_parts_patches = Some(Patches::new(
290                array_len, offset, indices, values,
291                None, // chunk_offsets not currently supported for ALPRD
292            )?);
293        }
294
295        Ok(())
296    }
297
298    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
299        let left_parts = array.left_parts().clone().execute::<PrimitiveArray>(ctx)?;
300        let right_parts = array.right_parts().clone().execute::<PrimitiveArray>(ctx)?;
301
302        // Decode the left_parts using our builtin dictionary.
303        let left_parts_dict = array.left_parts_dictionary();
304
305        let validity = array
306            .left_parts()
307            .validity()?
308            .to_array(array.len())
309            .execute::<Mask>(ctx)?;
310
311        let decoded_array = if array.is_f32() {
312            PrimitiveArray::new(
313                alp_rd_decode::<f32>(
314                    left_parts.into_buffer::<u16>(),
315                    left_parts_dict,
316                    array.right_bit_width,
317                    right_parts.into_buffer_mut::<u32>(),
318                    array.left_parts_patches(),
319                    ctx,
320                )?,
321                Validity::from_mask(validity, array.dtype().nullability()),
322            )
323        } else {
324            PrimitiveArray::new(
325                alp_rd_decode::<f64>(
326                    left_parts.into_buffer::<u16>(),
327                    left_parts_dict,
328                    array.right_bit_width,
329                    right_parts.into_buffer_mut::<u64>(),
330                    array.left_parts_patches(),
331                    ctx,
332                )?,
333                Validity::from_mask(validity, array.dtype().nullability()),
334            )
335        };
336
337        Ok(decoded_array.into_array())
338    }
339
340    fn reduce_parent(
341        array: &Self::Array,
342        parent: &ArrayRef,
343        child_idx: usize,
344    ) -> VortexResult<Option<ArrayRef>> {
345        RULES.evaluate(array, parent, child_idx)
346    }
347
348    fn execute_parent(
349        array: &Self::Array,
350        parent: &ArrayRef,
351        child_idx: usize,
352        ctx: &mut ExecutionCtx,
353    ) -> VortexResult<Option<ArrayRef>> {
354        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
355    }
356}
357
358#[derive(Clone, Debug)]
359pub struct ALPRDArray {
360    dtype: DType,
361    left_parts: ArrayRef,
362    left_parts_patches: Option<Patches>,
363    left_parts_dictionary: Buffer<u16>,
364    right_parts: ArrayRef,
365    right_bit_width: u8,
366    stats_set: ArrayStats,
367}
368
369#[derive(Debug)]
370pub struct ALPRDVTable;
371
372impl ALPRDVTable {
373    pub const ID: ArrayId = ArrayId::new_ref("vortex.alprd");
374}
375
376impl ALPRDArray {
377    /// Build a new `ALPRDArray` from components.
378    pub fn try_new(
379        dtype: DType,
380        left_parts: ArrayRef,
381        left_parts_dictionary: Buffer<u16>,
382        right_parts: ArrayRef,
383        right_bit_width: u8,
384        left_parts_patches: Option<Patches>,
385    ) -> VortexResult<Self> {
386        if !dtype.is_float() {
387            vortex_bail!("ALPRDArray given invalid DType ({dtype})");
388        }
389
390        let len = left_parts.len();
391        if right_parts.len() != len {
392            vortex_bail!(
393                "left_parts (len {}) and right_parts (len {}) must be of same length",
394                len,
395                right_parts.len()
396            );
397        }
398
399        if !left_parts.dtype().is_unsigned_int() {
400            vortex_bail!("left_parts dtype must be uint");
401        }
402        // we delegate array validity to the left_parts child
403        if dtype.is_nullable() != left_parts.dtype().is_nullable() {
404            vortex_bail!(
405                "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
406                dtype,
407                left_parts.dtype()
408            );
409        }
410
411        // we enforce right_parts to be non-nullable uint
412        if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
413            vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
414        }
415
416        let left_parts_patches = left_parts_patches
417            .map(|patches| {
418                if !patches.values().all_valid()? {
419                    vortex_bail!("patches must be all valid: {}", patches.values());
420                }
421                // TODO(ngates): assert the DType, don't cast it.
422                // TODO(joe): assert the DType, don't cast it in the next PR.
423                let mut patches = patches.cast_values(left_parts.dtype())?;
424                // Force execution of the lazy cast so patch values are materialized
425                // before serialization.
426                *patches.values_mut() = patches.values().to_canonical()?.into_array();
427                Ok(patches)
428            })
429            .transpose()?;
430
431        Ok(Self {
432            dtype,
433            left_parts,
434            left_parts_dictionary,
435            right_parts,
436            right_bit_width,
437            left_parts_patches,
438            stats_set: Default::default(),
439        })
440    }
441
442    /// Build a new `ALPRDArray` from components. This does not perform any validation, and instead
443    /// it constructs it from parts.
444    pub(crate) unsafe fn new_unchecked(
445        dtype: DType,
446        left_parts: ArrayRef,
447        left_parts_dictionary: Buffer<u16>,
448        right_parts: ArrayRef,
449        right_bit_width: u8,
450        left_parts_patches: Option<Patches>,
451    ) -> Self {
452        Self {
453            dtype,
454            left_parts,
455            left_parts_patches,
456            left_parts_dictionary,
457            right_parts,
458            right_bit_width,
459            stats_set: Default::default(),
460        }
461    }
462
463    /// Returns true if logical type of the array values is f32.
464    ///
465    /// Returns false if the logical type of the array values is f64.
466    #[inline]
467    pub fn is_f32(&self) -> bool {
468        matches!(&self.dtype, DType::Primitive(PType::F32, _))
469    }
470
471    /// The leftmost (most significant) bits of the floating point values stored in the array.
472    ///
473    /// These are bit-packed and dictionary encoded, and cannot directly be interpreted without
474    /// the metadata of this array.
475    pub fn left_parts(&self) -> &ArrayRef {
476        &self.left_parts
477    }
478
479    /// The rightmost (least significant) bits of the floating point values stored in the array.
480    pub fn right_parts(&self) -> &ArrayRef {
481        &self.right_parts
482    }
483
484    #[inline]
485    pub fn right_bit_width(&self) -> u8 {
486        self.right_bit_width
487    }
488
489    /// Patches of left-most bits.
490    pub fn left_parts_patches(&self) -> Option<&Patches> {
491        self.left_parts_patches.as_ref()
492    }
493
494    /// The dictionary that maps the codes in `left_parts` into bit patterns.
495    #[inline]
496    pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
497        &self.left_parts_dictionary
498    }
499
500    pub fn replace_left_parts_patches(&mut self, patches: Option<Patches>) {
501        self.left_parts_patches = patches;
502    }
503}
504
505impl ValidityChild<ALPRDVTable> for ALPRDVTable {
506    fn validity_child(array: &ALPRDArray) -> &ArrayRef {
507        array.left_parts()
508    }
509}
510
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ProstMetadata;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::PType;
    use vortex_array::patches::PatchesMetadata;
    use vortex_array::test_harness::check_metadata;

    use super::ALPRDMetadata;
    use crate::ALPRDFloat;
    use crate::alp_rd;

    /// Encodes a nullable float array with a seed chosen to force patches. Checks that decoding
    /// reproduces the input exactly, including the null positions.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");

        // Positions that are turned into nulls before encoding.
        const NULL_POSITIONS: [usize; 3] = [1, 5, 900];
        let expected: Vec<Option<T>> = reals
            .into_iter()
            .enumerate()
            .map(|(idx, value)| (!NULL_POSITIONS.contains(&idx)).then_some(value))
            .collect();

        let input = PrimitiveArray::from_option_iter(expected.iter().cloned());

        // This seed is known to trigger lots of patches.
        let encoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);
        let encoded = encoder.encode(&input);

        assert_arrays_eq!(
            encoded.to_primitive(),
            PrimitiveArray::from_option_iter(expected)
        );
    }

    /// Pins the serialized form of `ALPRDMetadata` against the stored golden file.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_alprd_metadata() {
        let patches = PatchesMetadata::new(usize::MAX, usize::MAX, PType::U64, None, None, None);
        let metadata = ALPRDMetadata {
            right_bit_width: u32::MAX,
            dict_len: 8,
            dict: Vec::new(),
            left_parts_ptype: PType::U64 as i32,
            patches: Some(patches),
        };

        check_metadata("alprd.metadata", ProstMetadata(metadata));
    }
}