Skip to main content

vortex_alp/alp_rd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use itertools::Itertools;
8use vortex_array::ArrayEq;
9use vortex_array::ArrayHash;
10use vortex_array::ArrayRef;
11use vortex_array::DeserializeMetadata;
12use vortex_array::DynArray;
13use vortex_array::ExecutionCtx;
14use vortex_array::ExecutionStep;
15use vortex_array::IntoArray;
16use vortex_array::Precision;
17use vortex_array::ProstMetadata;
18use vortex_array::SerializeMetadata;
19use vortex_array::arrays::PrimitiveArray;
20use vortex_array::buffer::BufferHandle;
21use vortex_array::dtype::DType;
22use vortex_array::dtype::Nullability;
23use vortex_array::dtype::PType;
24use vortex_array::patches::Patches;
25use vortex_array::patches::PatchesMetadata;
26use vortex_array::serde::ArrayChildren;
27use vortex_array::stats::ArrayStats;
28use vortex_array::stats::StatsSetRef;
29use vortex_array::validity::Validity;
30use vortex_array::vtable;
31use vortex_array::vtable::ArrayId;
32use vortex_array::vtable::VTable;
33use vortex_array::vtable::ValidityChild;
34use vortex_array::vtable::ValidityVTableFromChild;
35use vortex_array::vtable::patches_child;
36use vortex_array::vtable::patches_child_name;
37use vortex_array::vtable::patches_nchildren;
38use vortex_buffer::Buffer;
39use vortex_error::VortexResult;
40use vortex_error::vortex_bail;
41use vortex_error::vortex_ensure;
42use vortex_error::vortex_err;
43use vortex_error::vortex_panic;
44use vortex_mask::Mask;
45use vortex_session::VortexSession;
46
47use crate::alp_rd::kernel::PARENT_KERNELS;
48use crate::alp_rd::rules::RULES;
49use crate::alp_rd_decode;
50
51vtable!(ALPRD);
52
53#[derive(Clone, prost::Message)]
54pub struct ALPRDMetadata {
55    #[prost(uint32, tag = "1")]
56    right_bit_width: u32,
57    #[prost(uint32, tag = "2")]
58    dict_len: u32,
59    #[prost(uint32, repeated, tag = "3")]
60    dict: Vec<u32>,
61    #[prost(enumeration = "PType", tag = "4")]
62    left_parts_ptype: i32,
63    #[prost(message, tag = "5")]
64    patches: Option<PatchesMetadata>,
65}
66
67impl VTable for ALPRDVTable {
68    type Array = ALPRDArray;
69
70    type Metadata = ProstMetadata<ALPRDMetadata>;
71    type OperationsVTable = Self;
72    type ValidityVTable = ValidityVTableFromChild;
73
74    fn id(_array: &Self::Array) -> ArrayId {
75        Self::ID
76    }
77
78    fn len(array: &ALPRDArray) -> usize {
79        array.left_parts.len()
80    }
81
82    fn dtype(array: &ALPRDArray) -> &DType {
83        &array.dtype
84    }
85
86    fn stats(array: &ALPRDArray) -> StatsSetRef<'_> {
87        array.stats_set.to_ref(array.as_ref())
88    }
89
90    fn array_hash<H: std::hash::Hasher>(array: &ALPRDArray, state: &mut H, precision: Precision) {
91        array.dtype.hash(state);
92        array.left_parts.array_hash(state, precision);
93        array.left_parts_dictionary.array_hash(state, precision);
94        array.right_parts.array_hash(state, precision);
95        array.right_bit_width.hash(state);
96        array.left_parts_patches.array_hash(state, precision);
97    }
98
99    fn array_eq(array: &ALPRDArray, other: &ALPRDArray, precision: Precision) -> bool {
100        array.dtype == other.dtype
101            && array.left_parts.array_eq(&other.left_parts, precision)
102            && array
103                .left_parts_dictionary
104                .array_eq(&other.left_parts_dictionary, precision)
105            && array.right_parts.array_eq(&other.right_parts, precision)
106            && array.right_bit_width == other.right_bit_width
107            && array
108                .left_parts_patches
109                .array_eq(&other.left_parts_patches, precision)
110    }
111
112    fn nbuffers(_array: &ALPRDArray) -> usize {
113        0
114    }
115
116    fn buffer(_array: &ALPRDArray, idx: usize) -> BufferHandle {
117        vortex_panic!("ALPRDArray buffer index {idx} out of bounds")
118    }
119
120    fn buffer_name(_array: &ALPRDArray, _idx: usize) -> Option<String> {
121        None
122    }
123
124    fn nchildren(array: &ALPRDArray) -> usize {
125        2 + array.left_parts_patches().map_or(0, patches_nchildren)
126    }
127
128    fn child(array: &ALPRDArray, idx: usize) -> ArrayRef {
129        match idx {
130            0 => array.left_parts().clone(),
131            1 => array.right_parts().clone(),
132            _ => {
133                let patches = array
134                    .left_parts_patches()
135                    .unwrap_or_else(|| vortex_panic!("ALPRDArray child index {idx} out of bounds"));
136                patches_child(patches, idx - 2)
137            }
138        }
139    }
140
141    fn child_name(array: &ALPRDArray, idx: usize) -> String {
142        match idx {
143            0 => "left_parts".to_string(),
144            1 => "right_parts".to_string(),
145            _ => {
146                if array.left_parts_patches().is_none() {
147                    vortex_panic!("ALPRDArray child_name index {idx} out of bounds");
148                }
149                patches_child_name(idx - 2).to_string()
150            }
151        }
152    }
153
154    fn metadata(array: &ALPRDArray) -> VortexResult<Self::Metadata> {
155        let dict = array
156            .left_parts_dictionary()
157            .iter()
158            .map(|&i| i as u32)
159            .collect::<Vec<_>>();
160
161        Ok(ProstMetadata(ALPRDMetadata {
162            right_bit_width: array.right_bit_width() as u32,
163            dict_len: array.left_parts_dictionary().len() as u32,
164            dict,
165            left_parts_ptype: array.left_parts.dtype().as_ptype() as i32,
166            patches: array
167                .left_parts_patches()
168                .map(|p| p.to_metadata(array.len(), array.left_parts().dtype()))
169                .transpose()?,
170        }))
171    }
172
173    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
174        Ok(Some(metadata.serialize()))
175    }
176
177    fn deserialize(
178        bytes: &[u8],
179        _dtype: &DType,
180        _len: usize,
181        _buffers: &[BufferHandle],
182        _session: &VortexSession,
183    ) -> VortexResult<Self::Metadata> {
184        Ok(ProstMetadata(
185            <ProstMetadata<ALPRDMetadata> as DeserializeMetadata>::deserialize(bytes)?,
186        ))
187    }
188
189    fn build(
190        dtype: &DType,
191        len: usize,
192        metadata: &Self::Metadata,
193        _buffers: &[BufferHandle],
194        children: &dyn ArrayChildren,
195    ) -> VortexResult<ALPRDArray> {
196        if children.len() < 2 {
197            vortex_bail!(
198                "Expected at least 2 children for ALPRD encoding, found {}",
199                children.len()
200            );
201        }
202
203        let left_parts_dtype = DType::Primitive(metadata.0.left_parts_ptype(), dtype.nullability());
204        let left_parts = children.get(0, &left_parts_dtype, len)?;
205        let left_parts_dictionary: Buffer<u16> = metadata.0.dict.as_slice()
206            [0..metadata.0.dict_len as usize]
207            .iter()
208            .map(|&i| {
209                u16::try_from(i)
210                    .map_err(|_| vortex_err!("left_parts_dictionary code {i} does not fit in u16"))
211            })
212            .try_collect()?;
213
214        let right_parts_dtype = match &dtype {
215            DType::Primitive(PType::F32, _) => {
216                DType::Primitive(PType::U32, Nullability::NonNullable)
217            }
218            DType::Primitive(PType::F64, _) => {
219                DType::Primitive(PType::U64, Nullability::NonNullable)
220            }
221            _ => vortex_bail!("Expected f32 or f64 dtype, got {:?}", dtype),
222        };
223        let right_parts = children.get(1, &right_parts_dtype, len)?;
224
225        let left_parts_patches = metadata
226            .0
227            .patches
228            .map(|p| {
229                let indices = children.get(2, &p.indices_dtype()?, p.len()?)?;
230                let values = children.get(3, &left_parts_dtype, p.len()?)?;
231
232                Patches::new(
233                    len,
234                    p.offset()?,
235                    indices,
236                    values,
237                    // TODO(0ax1): handle chunk offsets
238                    None,
239                )
240            })
241            .transpose()?;
242
243        ALPRDArray::try_new(
244            dtype.clone(),
245            left_parts,
246            left_parts_dictionary,
247            right_parts,
248            u8::try_from(metadata.0.right_bit_width).map_err(|_| {
249                vortex_err!(
250                    "right_bit_width {} out of u8 range",
251                    metadata.0.right_bit_width
252                )
253            })?,
254            left_parts_patches,
255        )
256    }
257
258    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
259        // Children: left_parts, right_parts, patches (if present): indices, values
260        let patches_info = array
261            .left_parts_patches
262            .as_ref()
263            .map(|p| (p.array_len(), p.offset()));
264
265        let expected_children = if patches_info.is_some() { 4 } else { 2 };
266
267        vortex_ensure!(
268            children.len() == expected_children,
269            "ALPRDArray expects {} children, got {}",
270            expected_children,
271            children.len()
272        );
273
274        let mut children_iter = children.into_iter();
275        array.left_parts = children_iter
276            .next()
277            .ok_or_else(|| vortex_err!("Expected left_parts child"))?;
278        array.right_parts = children_iter
279            .next()
280            .ok_or_else(|| vortex_err!("Expected right_parts child"))?;
281
282        if let Some((array_len, offset)) = patches_info {
283            let indices = children_iter
284                .next()
285                .ok_or_else(|| vortex_err!("Expected patch indices child"))?;
286            let values = children_iter
287                .next()
288                .ok_or_else(|| vortex_err!("Expected patch values child"))?;
289
290            array.left_parts_patches = Some(Patches::new(
291                array_len, offset, indices, values,
292                None, // chunk_offsets not currently supported for ALPRD
293            )?);
294        }
295
296        Ok(())
297    }
298
299    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
300        let left_parts = array.left_parts().clone().execute::<PrimitiveArray>(ctx)?;
301        let right_parts = array.right_parts().clone().execute::<PrimitiveArray>(ctx)?;
302
303        // Decode the left_parts using our builtin dictionary.
304        let left_parts_dict = array.left_parts_dictionary();
305
306        let validity = array
307            .left_parts()
308            .validity()?
309            .to_array(array.len())
310            .execute::<Mask>(ctx)?;
311
312        let decoded_array = if array.is_f32() {
313            PrimitiveArray::new(
314                alp_rd_decode::<f32>(
315                    left_parts.into_buffer::<u16>(),
316                    left_parts_dict,
317                    array.right_bit_width,
318                    right_parts.into_buffer_mut::<u32>(),
319                    array.left_parts_patches(),
320                    ctx,
321                )?,
322                Validity::from_mask(validity, array.dtype().nullability()),
323            )
324        } else {
325            PrimitiveArray::new(
326                alp_rd_decode::<f64>(
327                    left_parts.into_buffer::<u16>(),
328                    left_parts_dict,
329                    array.right_bit_width,
330                    right_parts.into_buffer_mut::<u64>(),
331                    array.left_parts_patches(),
332                    ctx,
333                )?,
334                Validity::from_mask(validity, array.dtype().nullability()),
335            )
336        };
337
338        Ok(ExecutionStep::Done(decoded_array.into_array()))
339    }
340
341    fn reduce_parent(
342        array: &Self::Array,
343        parent: &ArrayRef,
344        child_idx: usize,
345    ) -> VortexResult<Option<ArrayRef>> {
346        RULES.evaluate(array, parent, child_idx)
347    }
348
349    fn execute_parent(
350        array: &Self::Array,
351        parent: &ArrayRef,
352        child_idx: usize,
353        ctx: &mut ExecutionCtx,
354    ) -> VortexResult<Option<ArrayRef>> {
355        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
356    }
357}
358
359#[derive(Clone, Debug)]
360pub struct ALPRDArray {
361    dtype: DType,
362    left_parts: ArrayRef,
363    left_parts_patches: Option<Patches>,
364    left_parts_dictionary: Buffer<u16>,
365    right_parts: ArrayRef,
366    right_bit_width: u8,
367    stats_set: ArrayStats,
368}
369
370#[derive(Debug)]
371pub struct ALPRDVTable;
372
373impl ALPRDVTable {
374    pub const ID: ArrayId = ArrayId::new_ref("vortex.alprd");
375}
376
377impl ALPRDArray {
378    /// Build a new `ALPRDArray` from components.
379    pub fn try_new(
380        dtype: DType,
381        left_parts: ArrayRef,
382        left_parts_dictionary: Buffer<u16>,
383        right_parts: ArrayRef,
384        right_bit_width: u8,
385        left_parts_patches: Option<Patches>,
386    ) -> VortexResult<Self> {
387        if !dtype.is_float() {
388            vortex_bail!("ALPRDArray given invalid DType ({dtype})");
389        }
390
391        let len = left_parts.len();
392        if right_parts.len() != len {
393            vortex_bail!(
394                "left_parts (len {}) and right_parts (len {}) must be of same length",
395                len,
396                right_parts.len()
397            );
398        }
399
400        if !left_parts.dtype().is_unsigned_int() {
401            vortex_bail!("left_parts dtype must be uint");
402        }
403        // we delegate array validity to the left_parts child
404        if dtype.is_nullable() != left_parts.dtype().is_nullable() {
405            vortex_bail!(
406                "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
407                dtype,
408                left_parts.dtype()
409            );
410        }
411
412        // we enforce right_parts to be non-nullable uint
413        if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
414            vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
415        }
416
417        let left_parts_patches = left_parts_patches
418            .map(|patches| {
419                if !patches.values().all_valid()? {
420                    vortex_bail!("patches must be all valid: {}", patches.values());
421                }
422                // TODO(ngates): assert the DType, don't cast it.
423                // TODO(joe): assert the DType, don't cast it in the next PR.
424                let mut patches = patches.cast_values(left_parts.dtype())?;
425                // Force execution of the lazy cast so patch values are materialized
426                // before serialization.
427                *patches.values_mut() = patches.values().to_canonical()?.into_array();
428                Ok(patches)
429            })
430            .transpose()?;
431
432        Ok(Self {
433            dtype,
434            left_parts,
435            left_parts_dictionary,
436            right_parts,
437            right_bit_width,
438            left_parts_patches,
439            stats_set: Default::default(),
440        })
441    }
442
443    /// Build a new `ALPRDArray` from components. This does not perform any validation, and instead
444    /// it constructs it from parts.
445    pub(crate) unsafe fn new_unchecked(
446        dtype: DType,
447        left_parts: ArrayRef,
448        left_parts_dictionary: Buffer<u16>,
449        right_parts: ArrayRef,
450        right_bit_width: u8,
451        left_parts_patches: Option<Patches>,
452    ) -> Self {
453        Self {
454            dtype,
455            left_parts,
456            left_parts_patches,
457            left_parts_dictionary,
458            right_parts,
459            right_bit_width,
460            stats_set: Default::default(),
461        }
462    }
463
464    /// Returns true if logical type of the array values is f32.
465    ///
466    /// Returns false if the logical type of the array values is f64.
467    #[inline]
468    pub fn is_f32(&self) -> bool {
469        matches!(&self.dtype, DType::Primitive(PType::F32, _))
470    }
471
472    /// The leftmost (most significant) bits of the floating point values stored in the array.
473    ///
474    /// These are bit-packed and dictionary encoded, and cannot directly be interpreted without
475    /// the metadata of this array.
476    pub fn left_parts(&self) -> &ArrayRef {
477        &self.left_parts
478    }
479
480    /// The rightmost (least significant) bits of the floating point values stored in the array.
481    pub fn right_parts(&self) -> &ArrayRef {
482        &self.right_parts
483    }
484
485    #[inline]
486    pub fn right_bit_width(&self) -> u8 {
487        self.right_bit_width
488    }
489
490    /// Patches of left-most bits.
491    pub fn left_parts_patches(&self) -> Option<&Patches> {
492        self.left_parts_patches.as_ref()
493    }
494
495    /// The dictionary that maps the codes in `left_parts` into bit patterns.
496    #[inline]
497    pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
498        &self.left_parts_dictionary
499    }
500
501    pub fn replace_left_parts_patches(&mut self, patches: Option<Patches>) {
502        self.left_parts_patches = patches;
503    }
504}
505
506impl ValidityChild<ALPRDVTable> for ALPRDVTable {
507    fn validity_child(array: &ALPRDArray) -> &ArrayRef {
508        array.left_parts()
509    }
510}
511
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ProstMetadata;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::PType;
    use vortex_array::patches::PatchesMetadata;
    use vortex_array::test_harness::check_metadata;

    use super::ALPRDMetadata;
    use crate::ALPRDFloat;
    use crate::alp_rd;

    /// Round-trips a nullable float array through ALP-RD encoding and decoding.
    ///
    /// The encoder seed is chosen so that encoding produces many patches.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");

        // Wrap every value, then null out a few positions (one near the end).
        let mut values: Vec<Option<T>> = reals.into_iter().map(Some).collect();
        for null_idx in [1, 5, 900] {
            values[null_idx] = None;
        }

        let input = PrimitiveArray::from_option_iter(values.iter().cloned());
        let expected = PrimitiveArray::from_option_iter(values);

        // A seed we know will trigger lots of patches.
        let encoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);
        let decoded = encoder.encode(&input).to_primitive();

        assert_arrays_eq!(decoded, expected);
    }

    /// Runs `ALPRDMetadata` through the `check_metadata` harness under the `alprd.metadata` key,
    /// with extreme field values.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_alprd_metadata() {
        let patches = PatchesMetadata::new(usize::MAX, usize::MAX, PType::U64, None, None, None);
        let metadata = ALPRDMetadata {
            right_bit_width: u32::MAX,
            dict_len: 8,
            dict: Vec::new(),
            left_parts_ptype: PType::U64 as i32,
            patches: Some(patches),
        };

        check_metadata("alprd.metadata", ProstMetadata(metadata));
    }
}