Skip to main content

vortex_alp/alp_rd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use itertools::Itertools;
9use vortex_array::ArrayEq;
10use vortex_array::ArrayHash;
11use vortex_array::ArrayRef;
12use vortex_array::DeserializeMetadata;
13use vortex_array::DynArray;
14use vortex_array::ExecutionCtx;
15use vortex_array::ExecutionResult;
16use vortex_array::IntoArray;
17use vortex_array::Precision;
18use vortex_array::ProstMetadata;
19use vortex_array::SerializeMetadata;
20use vortex_array::arrays::Primitive;
21use vortex_array::arrays::PrimitiveArray;
22use vortex_array::buffer::BufferHandle;
23use vortex_array::dtype::DType;
24use vortex_array::dtype::Nullability;
25use vortex_array::dtype::PType;
26use vortex_array::patches::Patches;
27use vortex_array::patches::PatchesMetadata;
28use vortex_array::require_child;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::stats::ArrayStats;
31use vortex_array::stats::StatsSetRef;
32use vortex_array::validity::Validity;
33use vortex_array::vtable;
34use vortex_array::vtable::ArrayId;
35use vortex_array::vtable::VTable;
36use vortex_array::vtable::ValidityChild;
37use vortex_array::vtable::ValidityVTableFromChild;
38use vortex_array::vtable::patches_child;
39use vortex_array::vtable::patches_child_name;
40use vortex_array::vtable::patches_nchildren;
41use vortex_buffer::Buffer;
42use vortex_error::VortexExpect;
43use vortex_error::VortexResult;
44use vortex_error::vortex_bail;
45use vortex_error::vortex_ensure;
46use vortex_error::vortex_err;
47use vortex_error::vortex_panic;
48use vortex_session::VortexSession;
49
50use crate::alp_rd::kernel::PARENT_KERNELS;
51use crate::alp_rd::rules::RULES;
52use crate::alp_rd_decode;
53
// NOTE(review): `vtable!` comes from `vortex_array`; presumably it generates the boilerplate
// that ties the `ALPRD` vtable type to `ALPRDArray` — confirm against the macro definition.
vtable!(ALPRD);
55
/// Protobuf-serialized metadata of an ALP-RD array.
///
/// Produced by `ALPRD::metadata` and consumed by `ALPRD::build` when deserializing.
#[derive(Clone, prost::Message)]
pub struct ALPRDMetadata {
    /// `ALPRDArray::right_bit_width`, widened from `u8`.
    #[prost(uint32, tag = "1")]
    right_bit_width: u32,
    /// Number of leading entries of `dict` that form the left-parts dictionary.
    #[prost(uint32, tag = "2")]
    dict_len: u32,
    /// Left-parts dictionary entries, each widened from `u16`.
    #[prost(uint32, repeated, tag = "3")]
    dict: Vec<u32>,
    /// Primitive type of the `left_parts` child (stored as the prost enum discriminant).
    #[prost(enumeration = "PType", tag = "4")]
    left_parts_ptype: i32,
    /// Metadata of the optional patches over `left_parts`; `None` when there are none.
    #[prost(message, tag = "5")]
    patches: Option<PatchesMetadata>,
}
69
70impl VTable for ALPRD {
71    type Array = ALPRDArray;
72
73    type Metadata = ProstMetadata<ALPRDMetadata>;
74    type OperationsVTable = Self;
75    type ValidityVTable = ValidityVTableFromChild;
76
77    fn vtable(_array: &Self::Array) -> &Self {
78        &ALPRD
79    }
80
81    fn id(&self) -> ArrayId {
82        Self::ID
83    }
84
85    fn len(array: &ALPRDArray) -> usize {
86        array.left_parts.len()
87    }
88
89    fn dtype(array: &ALPRDArray) -> &DType {
90        &array.dtype
91    }
92
93    fn stats(array: &ALPRDArray) -> StatsSetRef<'_> {
94        array.stats_set.to_ref(array.as_ref())
95    }
96
97    fn array_hash<H: std::hash::Hasher>(array: &ALPRDArray, state: &mut H, precision: Precision) {
98        array.dtype.hash(state);
99        array.left_parts.array_hash(state, precision);
100        array.left_parts_dictionary.array_hash(state, precision);
101        array.right_parts.array_hash(state, precision);
102        array.right_bit_width.hash(state);
103        array.left_parts_patches.array_hash(state, precision);
104    }
105
106    fn array_eq(array: &ALPRDArray, other: &ALPRDArray, precision: Precision) -> bool {
107        array.dtype == other.dtype
108            && array.left_parts.array_eq(&other.left_parts, precision)
109            && array
110                .left_parts_dictionary
111                .array_eq(&other.left_parts_dictionary, precision)
112            && array.right_parts.array_eq(&other.right_parts, precision)
113            && array.right_bit_width == other.right_bit_width
114            && array
115                .left_parts_patches
116                .array_eq(&other.left_parts_patches, precision)
117    }
118
119    fn nbuffers(_array: &ALPRDArray) -> usize {
120        0
121    }
122
123    fn buffer(_array: &ALPRDArray, idx: usize) -> BufferHandle {
124        vortex_panic!("ALPRDArray buffer index {idx} out of bounds")
125    }
126
127    fn buffer_name(_array: &ALPRDArray, _idx: usize) -> Option<String> {
128        None
129    }
130
131    fn nchildren(array: &ALPRDArray) -> usize {
132        2 + array.left_parts_patches().map_or(0, patches_nchildren)
133    }
134
135    fn child(array: &ALPRDArray, idx: usize) -> ArrayRef {
136        match idx {
137            0 => array.left_parts().clone(),
138            1 => array.right_parts().clone(),
139            _ => {
140                let patches = array
141                    .left_parts_patches()
142                    .unwrap_or_else(|| vortex_panic!("ALPRDArray child index {idx} out of bounds"));
143                patches_child(patches, idx - 2)
144            }
145        }
146    }
147
148    fn child_name(array: &ALPRDArray, idx: usize) -> String {
149        match idx {
150            0 => "left_parts".to_string(),
151            1 => "right_parts".to_string(),
152            _ => {
153                if array.left_parts_patches().is_none() {
154                    vortex_panic!("ALPRDArray child_name index {idx} out of bounds");
155                }
156                patches_child_name(idx - 2).to_string()
157            }
158        }
159    }
160
161    fn metadata(array: &ALPRDArray) -> VortexResult<Self::Metadata> {
162        let dict = array
163            .left_parts_dictionary()
164            .iter()
165            .map(|&i| i as u32)
166            .collect::<Vec<_>>();
167
168        Ok(ProstMetadata(ALPRDMetadata {
169            right_bit_width: array.right_bit_width() as u32,
170            dict_len: array.left_parts_dictionary().len() as u32,
171            dict,
172            left_parts_ptype: array.left_parts.dtype().as_ptype() as i32,
173            patches: array
174                .left_parts_patches()
175                .map(|p| p.to_metadata(array.len(), array.left_parts().dtype()))
176                .transpose()?,
177        }))
178    }
179
180    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
181        Ok(Some(metadata.serialize()))
182    }
183
184    fn deserialize(
185        bytes: &[u8],
186        _dtype: &DType,
187        _len: usize,
188        _buffers: &[BufferHandle],
189        _session: &VortexSession,
190    ) -> VortexResult<Self::Metadata> {
191        Ok(ProstMetadata(
192            <ProstMetadata<ALPRDMetadata> as DeserializeMetadata>::deserialize(bytes)?,
193        ))
194    }
195
196    fn build(
197        dtype: &DType,
198        len: usize,
199        metadata: &Self::Metadata,
200        _buffers: &[BufferHandle],
201        children: &dyn ArrayChildren,
202    ) -> VortexResult<ALPRDArray> {
203        if children.len() < 2 {
204            vortex_bail!(
205                "Expected at least 2 children for ALPRD encoding, found {}",
206                children.len()
207            );
208        }
209
210        let left_parts_dtype = DType::Primitive(metadata.0.left_parts_ptype(), dtype.nullability());
211        let left_parts = children.get(0, &left_parts_dtype, len)?;
212        let left_parts_dictionary: Buffer<u16> = metadata.0.dict.as_slice()
213            [0..metadata.0.dict_len as usize]
214            .iter()
215            .map(|&i| {
216                u16::try_from(i)
217                    .map_err(|_| vortex_err!("left_parts_dictionary code {i} does not fit in u16"))
218            })
219            .try_collect()?;
220
221        let right_parts_dtype = match &dtype {
222            DType::Primitive(PType::F32, _) => {
223                DType::Primitive(PType::U32, Nullability::NonNullable)
224            }
225            DType::Primitive(PType::F64, _) => {
226                DType::Primitive(PType::U64, Nullability::NonNullable)
227            }
228            _ => vortex_bail!("Expected f32 or f64 dtype, got {:?}", dtype),
229        };
230        let right_parts = children.get(1, &right_parts_dtype, len)?;
231
232        let left_parts_patches = metadata
233            .0
234            .patches
235            .map(|p| {
236                let indices = children.get(2, &p.indices_dtype()?, p.len()?)?;
237                let values = children.get(3, &left_parts_dtype, p.len()?)?;
238
239                Patches::new(
240                    len,
241                    p.offset()?,
242                    indices,
243                    values,
244                    // TODO(0ax1): handle chunk offsets
245                    None,
246                )
247            })
248            .transpose()?;
249
250        ALPRDArray::try_new(
251            dtype.clone(),
252            left_parts,
253            left_parts_dictionary,
254            right_parts,
255            u8::try_from(metadata.0.right_bit_width).map_err(|_| {
256                vortex_err!(
257                    "right_bit_width {} out of u8 range",
258                    metadata.0.right_bit_width
259                )
260            })?,
261            left_parts_patches,
262        )
263    }
264
265    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
266        // Children: left_parts, right_parts, patches (if present): indices, values
267        let patches_info = array
268            .left_parts_patches
269            .as_ref()
270            .map(|p| (p.array_len(), p.offset()));
271
272        let expected_children = if patches_info.is_some() { 4 } else { 2 };
273
274        vortex_ensure!(
275            children.len() == expected_children,
276            "ALPRDArray expects {} children, got {}",
277            expected_children,
278            children.len()
279        );
280
281        let mut children_iter = children.into_iter();
282        array.left_parts = children_iter
283            .next()
284            .ok_or_else(|| vortex_err!("Expected left_parts child"))?;
285        array.right_parts = children_iter
286            .next()
287            .ok_or_else(|| vortex_err!("Expected right_parts child"))?;
288
289        if let Some((array_len, offset)) = patches_info {
290            let indices = children_iter
291                .next()
292                .ok_or_else(|| vortex_err!("Expected patch indices child"))?;
293            let values = children_iter
294                .next()
295                .ok_or_else(|| vortex_err!("Expected patch values child"))?;
296
297            array.left_parts_patches = Some(Patches::new(
298                array_len, offset, indices, values,
299                None, // chunk_offsets not currently supported for ALPRD
300            )?);
301        }
302
303        Ok(())
304    }
305
306    fn execute(array: Arc<Self::Array>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
307        let array = require_child!(Self, array, array.left_parts(), 0 => Primitive);
308        let array = require_child!(Self, array, array.right_parts(), 1 => Primitive);
309
310        let right_bit_width = array.right_bit_width();
311        let ALPRDArrayParts {
312            left_parts,
313            right_parts,
314            left_parts_dictionary,
315            left_parts_patches,
316            dtype,
317            ..
318        } = Arc::unwrap_or_clone(array).into_parts();
319        let ptype = dtype.as_ptype();
320
321        let left_parts = left_parts
322            .try_into::<Primitive>()
323            .ok()
324            .vortex_expect("ALPRD execute: left_parts is primitive");
325        let right_parts = right_parts
326            .try_into::<Primitive>()
327            .ok()
328            .vortex_expect("ALPRD execute: right_parts is primitive");
329
330        // Decode the left_parts using our builtin dictionary.
331        let left_parts_dict = left_parts_dictionary;
332        let validity = left_parts.validity_mask()?;
333
334        let decoded_array = if ptype == PType::F32 {
335            // TODO(joe): use iterative execution for the patches.
336            PrimitiveArray::new(
337                alp_rd_decode::<f32>(
338                    left_parts.into_buffer::<u16>(),
339                    &left_parts_dict,
340                    right_bit_width,
341                    right_parts.into_buffer::<u32>(),
342                    left_parts_patches,
343                    ctx,
344                )?,
345                Validity::from_mask(validity, dtype.nullability()),
346            )
347        } else {
348            PrimitiveArray::new(
349                alp_rd_decode::<f64>(
350                    left_parts.into_buffer::<u16>(),
351                    &left_parts_dict,
352                    right_bit_width,
353                    right_parts.into_buffer::<u64>(),
354                    left_parts_patches,
355                    ctx,
356                )?,
357                Validity::from_mask(validity, dtype.nullability()),
358            )
359        };
360
361        Ok(ExecutionResult::done(decoded_array.into_array()))
362    }
363
364    fn reduce_parent(
365        array: &Self::Array,
366        parent: &ArrayRef,
367        child_idx: usize,
368    ) -> VortexResult<Option<ArrayRef>> {
369        RULES.evaluate(array, parent, child_idx)
370    }
371
372    fn execute_parent(
373        array: &Self::Array,
374        parent: &ArrayRef,
375        child_idx: usize,
376        ctx: &mut ExecutionCtx,
377    ) -> VortexResult<Option<ArrayRef>> {
378        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
379    }
380}
381
/// An ALP-RD encoded array of floating point values.
///
/// Each value is split into its left (most significant) bits, stored as codes into
/// `left_parts_dictionary` in `left_parts`, and its right (least significant) bits, stored
/// in `right_parts`. NOTE(review): `left_parts_patches` presumably holds left bits that the
/// dictionary cannot represent — confirm against the encoder.
#[derive(Clone, Debug)]
pub struct ALPRDArray {
    /// Logical dtype of the array; `try_new` rejects non-float dtypes.
    dtype: DType,
    /// Dictionary codes for the left bits; also the child the array's validity comes from.
    left_parts: ArrayRef,
    /// Optional patches applied over `left_parts`.
    left_parts_patches: Option<Patches>,
    /// Maps codes in `left_parts` to left-bit patterns.
    left_parts_dictionary: Buffer<u16>,
    /// Right bits of each value; `try_new` requires a non-nullable unsigned integer array.
    right_parts: ArrayRef,
    /// Bit width passed to `alp_rd_decode`; presumably the number of bits held in
    /// `right_parts` — TODO confirm.
    right_bit_width: u8,
    /// Statistics set exposed through `VTable::stats`.
    stats_set: ArrayStats,
}
392
/// Owned components of an [`ALPRDArray`], as returned by `ALPRDArray::into_parts`.
///
/// NOTE(review): `right_bit_width` is not part of this struct, so the parts alone are not
/// enough to rebuild an array; read `ALPRDArray::right_bit_width` before consuming it.
#[derive(Clone, Debug)]
pub struct ALPRDArrayParts {
    /// Logical float dtype of the array.
    pub dtype: DType,
    /// Dictionary codes for the left bits (also carries the array's validity).
    pub left_parts: ArrayRef,
    /// Optional patches over `left_parts`.
    pub left_parts_patches: Option<Patches>,
    /// Maps codes in `left_parts` to left-bit patterns.
    pub left_parts_dictionary: Buffer<u16>,
    /// Right bits of each value.
    pub right_parts: ArrayRef,
}
401
/// Marker type implementing the ALP-RD encoding vtable.
#[derive(Clone, Debug)]
pub struct ALPRD;

impl ALPRD {
    /// Identifier of the ALP-RD encoding, returned by `VTable::id`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.alprd");
}
408
409impl ALPRDArray {
410    /// Build a new `ALPRDArray` from components.
411    pub fn try_new(
412        dtype: DType,
413        left_parts: ArrayRef,
414        left_parts_dictionary: Buffer<u16>,
415        right_parts: ArrayRef,
416        right_bit_width: u8,
417        left_parts_patches: Option<Patches>,
418    ) -> VortexResult<Self> {
419        if !dtype.is_float() {
420            vortex_bail!("ALPRDArray given invalid DType ({dtype})");
421        }
422
423        let len = left_parts.len();
424        if right_parts.len() != len {
425            vortex_bail!(
426                "left_parts (len {}) and right_parts (len {}) must be of same length",
427                len,
428                right_parts.len()
429            );
430        }
431
432        if !left_parts.dtype().is_unsigned_int() {
433            vortex_bail!("left_parts dtype must be uint");
434        }
435        // we delegate array validity to the left_parts child
436        if dtype.is_nullable() != left_parts.dtype().is_nullable() {
437            vortex_bail!(
438                "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
439                dtype,
440                left_parts.dtype()
441            );
442        }
443
444        // we enforce right_parts to be non-nullable uint
445        if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
446            vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
447        }
448
449        let left_parts_patches = left_parts_patches
450            .map(|patches| {
451                if !patches.values().all_valid()? {
452                    vortex_bail!("patches must be all valid: {}", patches.values());
453                }
454                // TODO(ngates): assert the DType, don't cast it.
455                // TODO(joe): assert the DType, don't cast it in the next PR.
456                let mut patches = patches.cast_values(left_parts.dtype())?;
457                // Force execution of the lazy cast so patch values are materialized
458                // before serialization.
459                *patches.values_mut() = patches.values().to_canonical()?.into_array();
460                Ok(patches)
461            })
462            .transpose()?;
463
464        Ok(Self {
465            dtype,
466            left_parts,
467            left_parts_dictionary,
468            right_parts,
469            right_bit_width,
470            left_parts_patches,
471            stats_set: Default::default(),
472        })
473    }
474
475    /// Build a new `ALPRDArray` from components. This does not perform any validation, and instead
476    /// it constructs it from parts.
477    pub(crate) unsafe fn new_unchecked(
478        dtype: DType,
479        left_parts: ArrayRef,
480        left_parts_dictionary: Buffer<u16>,
481        right_parts: ArrayRef,
482        right_bit_width: u8,
483        left_parts_patches: Option<Patches>,
484    ) -> Self {
485        Self {
486            dtype,
487            left_parts,
488            left_parts_patches,
489            left_parts_dictionary,
490            right_parts,
491            right_bit_width,
492            stats_set: Default::default(),
493        }
494    }
495
496    /// Return all the owned parts of the array
497    pub fn into_parts(self) -> ALPRDArrayParts {
498        ALPRDArrayParts {
499            dtype: self.dtype,
500            left_parts: self.left_parts,
501            left_parts_patches: self.left_parts_patches,
502            left_parts_dictionary: self.left_parts_dictionary,
503            right_parts: self.right_parts,
504        }
505    }
506
507    /// Returns true if logical type of the array values is f32.
508    ///
509    /// Returns false if the logical type of the array values is f64.
510    #[inline]
511    pub fn is_f32(&self) -> bool {
512        matches!(&self.dtype, DType::Primitive(PType::F32, _))
513    }
514
515    /// The leftmost (most significant) bits of the floating point values stored in the array.
516    ///
517    /// These are bit-packed and dictionary encoded, and cannot directly be interpreted without
518    /// the metadata of this array.
519    pub fn left_parts(&self) -> &ArrayRef {
520        &self.left_parts
521    }
522
523    /// The rightmost (least significant) bits of the floating point values stored in the array.
524    pub fn right_parts(&self) -> &ArrayRef {
525        &self.right_parts
526    }
527
528    #[inline]
529    pub fn right_bit_width(&self) -> u8 {
530        self.right_bit_width
531    }
532
533    /// Patches of left-most bits.
534    pub fn left_parts_patches(&self) -> Option<&Patches> {
535        self.left_parts_patches.as_ref()
536    }
537
538    /// The dictionary that maps the codes in `left_parts` into bit patterns.
539    #[inline]
540    pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
541        &self.left_parts_dictionary
542    }
543
544    pub fn replace_left_parts_patches(&mut self, patches: Option<Patches>) {
545        self.left_parts_patches = patches;
546    }
547}
548
549impl ValidityChild<ALPRD> for ALPRD {
550    fn validity_child(array: &ALPRDArray) -> &ArrayRef {
551        array.left_parts()
552    }
553}
554
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ProstMetadata;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::PType;
    use vortex_array::patches::PatchesMetadata;
    use vortex_array::test_harness::check_metadata;

    use super::ALPRDMetadata;
    use crate::ALPRDFloat;
    use crate::alp_rd;

    /// Encodes a nullable float array with an `RDEncoder` built from an unrelated sample,
    /// decodes it again, and checks that values and nulls survive the round trip.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");
        // Null out some of the values.
        let mut reals: Vec<Option<T>> = reals.into_iter().map(Some).collect();
        reals[1] = None;
        reals[5] = None;
        reals[900] = None;

        // Create a new array from this.
        let real_array = PrimitiveArray::from_option_iter(reals.iter().cloned());

        // Pick a seed that we know will trigger lots of patches.
        // The encoder is constructed from the single sample `seed^-2` rather than from
        // `reals`, so its dictionary is not fitted to the data being encoded.
        let encoder: alp_rd::RDEncoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);

        let rd_array = encoder.encode(&real_array);

        let decoded = rd_array.to_primitive();

        assert_arrays_eq!(decoded, PrimitiveArray::from_option_iter(reals));
    }

    /// Checks the serialized form of an `ALPRDMetadata` filled with extreme values.
    ///
    /// NOTE(review): `check_metadata` presumably compares against a stored snapshot named
    /// "alprd.metadata", so changing any value here likely requires regenerating it — confirm.
    /// `dict_len` (8) does not match the empty `dict`; this metadata is only serialized
    /// here, never passed to `build`.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_alprd_metadata() {
        check_metadata(
            "alprd.metadata",
            ProstMetadata(ALPRDMetadata {
                right_bit_width: u32::MAX,
                patches: Some(PatchesMetadata::new(
                    usize::MAX,
                    usize::MAX,
                    PType::U64,
                    None,
                    None,
                    None,
                )),
                dict: Vec::new(),
                left_parts_ptype: PType::U64 as i32,
                dict_len: 8,
            }),
        );
    }
}