vortex_alp/alp_rd/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use vortex_array::arrays::PrimitiveArray;
7use vortex_array::patches::Patches;
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::validity::Validity;
10use vortex_array::vtable::{
11    ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityChild, ValidityVTableFromChild,
12};
13use vortex_array::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, ToCanonical, vtable};
14use vortex_buffer::Buffer;
15use vortex_dtype::{DType, PType};
16use vortex_error::{VortexResult, vortex_bail};
17
18use crate::alp_rd::alp_rd_decode;
19
20vtable!(ALPRD);
21
22impl VTable for ALPRDVTable {
23    type Array = ALPRDArray;
24    type Encoding = ALPRDEncoding;
25
26    type ArrayVTable = Self;
27    type CanonicalVTable = Self;
28    type OperationsVTable = Self;
29    type ValidityVTable = ValidityVTableFromChild;
30    type VisitorVTable = Self;
31    type ComputeVTable = NotSupported;
32    type EncodeVTable = Self;
33    type SerdeVTable = Self;
34    type PipelineVTable = NotSupported;
35
36    fn id(_encoding: &Self::Encoding) -> EncodingId {
37        EncodingId::new_ref("vortex.alprd")
38    }
39
40    fn encoding(_array: &Self::Array) -> EncodingRef {
41        EncodingRef::new_ref(ALPRDEncoding.as_ref())
42    }
43}
44
45#[derive(Clone, Debug)]
46pub struct ALPRDArray {
47    dtype: DType,
48    left_parts: ArrayRef,
49    left_parts_patches: Option<Patches>,
50    left_parts_dictionary: Buffer<u16>,
51    right_parts: ArrayRef,
52    right_bit_width: u8,
53    stats_set: ArrayStats,
54}
55
56#[derive(Clone, Debug)]
57pub struct ALPRDEncoding;
58
59impl ALPRDArray {
60    /// Build a new `ALPRDArray` from components.
61    pub fn try_new(
62        dtype: DType,
63        left_parts: ArrayRef,
64        left_parts_dictionary: Buffer<u16>,
65        right_parts: ArrayRef,
66        right_bit_width: u8,
67        left_parts_patches: Option<Patches>,
68    ) -> VortexResult<Self> {
69        if !dtype.is_float() {
70            vortex_bail!("ALPRDArray given invalid DType ({dtype})");
71        }
72
73        let len = left_parts.len();
74        if right_parts.len() != len {
75            vortex_bail!(
76                "left_parts (len {}) and right_parts (len {}) must be of same length",
77                len,
78                right_parts.len()
79            );
80        }
81
82        if !left_parts.dtype().is_unsigned_int() {
83            vortex_bail!("left_parts dtype must be uint");
84        }
85        // we delegate array validity to the left_parts child
86        if dtype.is_nullable() != left_parts.dtype().is_nullable() {
87            vortex_bail!(
88                "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
89                dtype,
90                left_parts.dtype()
91            );
92        }
93
94        // we enforce right_parts to be non-nullable uint
95        if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
96            vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
97        }
98
99        let left_parts_patches = left_parts_patches
100            .map(|patches| {
101                if !patches.values().all_valid() {
102                    vortex_bail!("patches must be all valid: {}", patches.values());
103                }
104                // TODO(ngates): assert the DType, don't cast it.
105                patches.cast_values(left_parts.dtype())
106            })
107            .transpose()?;
108
109        Ok(Self {
110            dtype,
111            left_parts,
112            left_parts_dictionary,
113            right_parts,
114            right_bit_width,
115            left_parts_patches,
116            stats_set: Default::default(),
117        })
118    }
119
120    /// Build a new `ALPRDArray` from components. This does not perform any validation, and instead
121    /// it constructs it from parts.
122    pub(crate) unsafe fn new_unchecked(
123        dtype: DType,
124        left_parts: ArrayRef,
125        left_parts_dictionary: Buffer<u16>,
126        right_parts: ArrayRef,
127        right_bit_width: u8,
128        left_parts_patches: Option<Patches>,
129    ) -> Self {
130        Self {
131            dtype,
132            left_parts,
133            left_parts_patches,
134            left_parts_dictionary,
135            right_parts,
136            right_bit_width,
137            stats_set: Default::default(),
138        }
139    }
140
141    /// Returns true if logical type of the array values is f32.
142    ///
143    /// Returns false if the logical type of the array values is f64.
144    #[inline]
145    pub fn is_f32(&self) -> bool {
146        matches!(&self.dtype, DType::Primitive(PType::F32, _))
147    }
148
149    /// The leftmost (most significant) bits of the floating point values stored in the array.
150    ///
151    /// These are bit-packed and dictionary encoded, and cannot directly be interpreted without
152    /// the metadata of this array.
153    pub fn left_parts(&self) -> &ArrayRef {
154        &self.left_parts
155    }
156
157    /// The rightmost (least significant) bits of the floating point values stored in the array.
158    pub fn right_parts(&self) -> &ArrayRef {
159        &self.right_parts
160    }
161
162    #[inline]
163    pub fn right_bit_width(&self) -> u8 {
164        self.right_bit_width
165    }
166
167    /// Patches of left-most bits.
168    pub fn left_parts_patches(&self) -> Option<&Patches> {
169        self.left_parts_patches.as_ref()
170    }
171
172    /// The dictionary that maps the codes in `left_parts` into bit patterns.
173    #[inline]
174    pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
175        &self.left_parts_dictionary
176    }
177
178    pub fn replace_left_parts_patches(&mut self, patches: Option<Patches>) {
179        self.left_parts_patches = patches;
180    }
181}
182
183impl ValidityChild<ALPRDVTable> for ALPRDVTable {
184    fn validity_child(array: &ALPRDArray) -> &dyn Array {
185        array.left_parts()
186    }
187}
188
189impl ArrayVTable<ALPRDVTable> for ALPRDVTable {
190    fn len(array: &ALPRDArray) -> usize {
191        array.left_parts.len()
192    }
193
194    fn dtype(array: &ALPRDArray) -> &DType {
195        &array.dtype
196    }
197
198    fn stats(array: &ALPRDArray) -> StatsSetRef<'_> {
199        array.stats_set.to_ref(array.as_ref())
200    }
201}
202
203impl CanonicalVTable<ALPRDVTable> for ALPRDVTable {
204    fn canonicalize(array: &ALPRDArray) -> Canonical {
205        let left_parts = array.left_parts().to_primitive();
206        let right_parts = array.right_parts().to_primitive();
207
208        // Decode the left_parts using our builtin dictionary.
209        let left_parts_dict = array.left_parts_dictionary();
210
211        let decoded_array = if array.is_f32() {
212            PrimitiveArray::new(
213                alp_rd_decode::<f32>(
214                    left_parts.into_buffer::<u16>(),
215                    left_parts_dict,
216                    array.right_bit_width,
217                    right_parts.into_buffer_mut::<u32>(),
218                    array.left_parts_patches(),
219                ),
220                Validity::copy_from_array(array.as_ref()),
221            )
222        } else {
223            PrimitiveArray::new(
224                alp_rd_decode::<f64>(
225                    left_parts.into_buffer::<u16>(),
226                    left_parts_dict,
227                    array.right_bit_width,
228                    right_parts.into_buffer_mut::<u64>(),
229                    array.left_parts_patches(),
230                ),
231                Validity::copy_from_array(array.as_ref()),
232            )
233        };
234
235        Canonical::Primitive(decoded_array)
236    }
237}
238
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;

    use crate::{ALPRDFloat, alp_rd};

    /// Round-trips a 1024-element float array containing a few nulls through `RDEncoder`, using
    /// a seed known to trigger many patches, and checks the decoded values. Null slots are
    /// expected to decode to the type's default value.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");

        // Positions nulled out before encoding.
        const NULL_POSITIONS: [usize; 3] = [1, 5, 900];
        let values: Vec<Option<T>> = reals
            .into_iter()
            .enumerate()
            .map(|(idx, v)| (!NULL_POSITIONS.contains(&idx)).then_some(v))
            .collect();

        let input = PrimitiveArray::from_option_iter(values.iter().cloned());

        // Encoder built from a single sample derived from a seed that we know will trigger
        // lots of patches.
        let encoder: alp_rd::RDEncoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);
        let decoded = encoder.encode(&input).to_primitive();

        let expected: Vec<T> = values.into_iter().map(Option::unwrap_or_default).collect();
        assert_eq!(decoded.as_slice::<T>(), &expected);
    }
}