vortex_alp/alp_rd/
array.rs

1use std::fmt::Debug;
2
3use vortex_array::arrays::PrimitiveArray;
4use vortex_array::patches::Patches;
5use vortex_array::stats::{ArrayStats, StatsSetRef};
6use vortex_array::validity::Validity;
7use vortex_array::vtable::VTableRef;
8use vortex_array::{
9    Array, ArrayCanonicalImpl, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl,
10    Canonical, Encoding, SerdeMetadata, ToCanonical,
11};
12use vortex_buffer::Buffer;
13use vortex_dtype::{DType, PType};
14use vortex_error::{VortexResult, vortex_bail};
15use vortex_mask::Mask;
16
17use crate::alp_rd::alp_rd_decode;
18use crate::alp_rd::serde::ALPRDMetadata;
19
/// An array of floating point values stored as separately encoded "left" and "right" bit parts.
///
/// Instances built through [`ALPRDArray::try_new`] satisfy the checks performed there: a float
/// `dtype`, equal-length children, and unsigned-integer children.
#[derive(Clone, Debug)]
pub struct ALPRDArray {
    /// Logical dtype of the array; `try_new` rejects anything that is not a float.
    dtype: DType,
    /// Leftmost (most significant) bits of each value, stored as dictionary codes.
    ///
    /// Must have an unsigned integer dtype. The validity of the whole array is delegated to
    /// this child, so its nullability has to match `dtype`.
    left_parts: ArrayRef,
    /// Optional patches for the left parts. `try_new` requires the patch values to be all
    /// valid and casts them to the dtype of `left_parts`.
    left_parts_patches: Option<Patches>,
    /// Dictionary that maps the codes in `left_parts` into bit patterns.
    left_parts_dictionary: Buffer<u16>,
    /// Rightmost (least significant) bits of each value; must be a non-nullable unsigned int.
    right_parts: ArrayRef,
    /// Forwarded to `alp_rd_decode` when canonicalizing.
    // NOTE(review): presumably the number of bits of each value held in `right_parts` —
    // confirm against the encoder.
    right_bit_width: u8,
    /// Statistics for this array; starts out as the default value in `try_new`.
    stats_set: ArrayStats,
}
30
/// Marker type identifying the encoding implemented by [`ALPRDArray`].
pub struct ALPRDEncoding;
impl Encoding for ALPRDEncoding {
    /// The array type produced by this encoding.
    type Array = ALPRDArray;
    /// Metadata associated with this encoding: `ALPRDMetadata` wrapped in `SerdeMetadata`.
    type Metadata = SerdeMetadata<ALPRDMetadata>;
}
36
37impl ALPRDArray {
38    pub fn try_new(
39        dtype: DType,
40        left_parts: ArrayRef,
41        left_parts_dictionary: Buffer<u16>,
42        right_parts: ArrayRef,
43        right_bit_width: u8,
44        left_parts_patches: Option<Patches>,
45    ) -> VortexResult<Self> {
46        if !dtype.is_float() {
47            vortex_bail!("ALPRDArray given invalid DType ({dtype})");
48        }
49
50        let len = left_parts.len();
51        if right_parts.len() != len {
52            vortex_bail!(
53                "left_parts (len {}) and right_parts (len {}) must be of same length",
54                len,
55                right_parts.len()
56            );
57        }
58
59        if !left_parts.dtype().is_unsigned_int() {
60            vortex_bail!("left_parts dtype must be uint");
61        }
62        // we delegate array validity to the left_parts child
63        if dtype.is_nullable() != left_parts.dtype().is_nullable() {
64            vortex_bail!(
65                "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
66                dtype,
67                left_parts.dtype()
68            );
69        }
70
71        // we enforce right_parts to be non-nullable uint
72        if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
73            vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
74        }
75
76        let left_parts_patches = left_parts_patches
77            .map(|patches| {
78                if !patches.values().all_valid()? {
79                    vortex_bail!("patches must be all valid: {}", patches.values());
80                }
81                // TODO(ngates): assert the DType, don't cast it.
82                patches.cast_values(left_parts.dtype())
83            })
84            .transpose()?;
85
86        Ok(Self {
87            dtype,
88            left_parts,
89            left_parts_patches,
90            left_parts_dictionary,
91            right_parts,
92            right_bit_width,
93            stats_set: Default::default(),
94        })
95    }
96
97    /// Returns true if logical type of the array values is f32.
98    ///
99    /// Returns false if the logical type of the array values is f64.
100    #[inline]
101    pub fn is_f32(&self) -> bool {
102        matches!(&self.dtype, DType::Primitive(PType::F32, _))
103    }
104
105    /// The leftmost (most significant) bits of the floating point values stored in the array.
106    ///
107    /// These are bit-packed and dictionary encoded, and cannot directly be interpreted without
108    /// the metadata of this array.
109    pub fn left_parts(&self) -> &ArrayRef {
110        &self.left_parts
111    }
112
113    /// The rightmost (least significant) bits of the floating point values stored in the array.
114    pub fn right_parts(&self) -> &ArrayRef {
115        &self.right_parts
116    }
117
118    #[inline]
119    pub fn right_bit_width(&self) -> u8 {
120        self.right_bit_width
121    }
122
123    /// Patches of left-most bits.
124    pub fn left_parts_patches(&self) -> Option<&Patches> {
125        self.left_parts_patches.as_ref()
126    }
127
128    /// The dictionary that maps the codes in `left_parts` into bit patterns.
129    #[inline]
130    pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
131        &self.left_parts_dictionary
132    }
133
134    pub fn replace_left_parts_patches(&mut self, patches: Option<Patches>) {
135        self.left_parts_patches = patches;
136    }
137}
138
139impl ArrayImpl for ALPRDArray {
140    type Encoding = ALPRDEncoding;
141
142    fn _len(&self) -> usize {
143        self.left_parts.len()
144    }
145
146    fn _dtype(&self) -> &DType {
147        &self.dtype
148    }
149
150    fn _vtable(&self) -> VTableRef {
151        VTableRef::new_ref(&ALPRDEncoding)
152    }
153
154    fn _with_children(&self, children: &[ArrayRef]) -> VortexResult<Self> {
155        let left_parts = children[0].clone();
156        let right_parts = children[1].clone();
157
158        let left_part_patches = self.left_parts_patches().map(|existing| {
159            let indices = children[2].clone();
160            let values = children[3].clone();
161            Patches::new(existing.array_len(), existing.offset(), indices, values)
162        });
163
164        ALPRDArray::try_new(
165            self.dtype().clone(),
166            left_parts,
167            self.left_parts_dictionary().clone(),
168            right_parts,
169            self.right_bit_width(),
170            left_part_patches,
171        )
172    }
173}
174
175impl ArrayCanonicalImpl for ALPRDArray {
176    fn _to_canonical(&self) -> VortexResult<Canonical> {
177        let left_parts = self.left_parts().to_primitive()?;
178        let right_parts = self.right_parts().to_primitive()?;
179
180        // Decode the left_parts using our builtin dictionary.
181        let left_parts_dict = self.left_parts_dictionary();
182
183        let decoded_array = if self.is_f32() {
184            PrimitiveArray::new(
185                alp_rd_decode::<f32>(
186                    left_parts.into_buffer::<u16>(),
187                    left_parts_dict,
188                    self.right_bit_width,
189                    right_parts.into_buffer_mut::<u32>(),
190                    self.left_parts_patches(),
191                )?,
192                Validity::copy_from_array(self)?,
193            )
194        } else {
195            PrimitiveArray::new(
196                alp_rd_decode::<f64>(
197                    left_parts.into_buffer::<u16>(),
198                    left_parts_dict,
199                    self.right_bit_width,
200                    right_parts.into_buffer_mut::<u64>(),
201                    self.left_parts_patches(),
202                )?,
203                Validity::copy_from_array(self)?,
204            )
205        };
206
207        Ok(Canonical::Primitive(decoded_array))
208    }
209}
210
211impl ArrayStatisticsImpl for ALPRDArray {
212    fn _stats_ref(&self) -> StatsSetRef<'_> {
213        self.stats_set.to_ref(self)
214    }
215}
216
217impl ArrayValidityImpl for ALPRDArray {
218    fn _is_valid(&self, index: usize) -> VortexResult<bool> {
219        self.left_parts().is_valid(index)
220    }
221
222    fn _all_valid(&self) -> VortexResult<bool> {
223        self.left_parts().all_valid()
224    }
225
226    fn _all_invalid(&self) -> VortexResult<bool> {
227        self.left_parts().all_invalid()
228    }
229
230    fn _validity_mask(&self) -> VortexResult<Mask> {
231        self.left_parts().validity_mask()
232    }
233}
234
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;

    use crate::{ALPRDFloat, alp_rd};

    /// Round-trips a nullable float array through the RD encoder and checks the decoded
    /// buffer, expecting the default value of `T` at the null positions.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");

        // Wrap every value in `Some`, then punch a few nulls into the fixture.
        let mut nullable_reals: Vec<Option<T>> = reals.into_iter().map(Some).collect();
        for null_index in [1, 5, 900] {
            nullable_reals[null_index] = None;
        }

        let input = PrimitiveArray::from_option_iter(nullable_reals.iter().cloned());

        // Pick a seed that we know will trigger lots of patches.
        let encoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);
        let decoded = encoder.encode(&input).to_primitive().unwrap();

        let expected: Vec<T> = nullable_reals
            .into_iter()
            .map(Option::unwrap_or_default)
            .collect();
        assert_eq!(decoded.as_slice::<T>(), &expected);
    }
}