vortex_alp/alp_rd/
array.rs1use std::fmt::Debug;
5
6use vortex_array::arrays::PrimitiveArray;
7use vortex_array::patches::Patches;
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::validity::Validity;
10use vortex_array::vtable::{
11 ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityChild, ValidityVTableFromChild,
12};
13use vortex_array::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, ToCanonical, vtable};
14use vortex_buffer::Buffer;
15use vortex_dtype::{DType, PType};
16use vortex_error::{VortexResult, vortex_bail};
17
18use crate::alp_rd::alp_rd_decode;
19
20vtable!(ALPRD);
21
22impl VTable for ALPRDVTable {
23 type Array = ALPRDArray;
24 type Encoding = ALPRDEncoding;
25
26 type ArrayVTable = Self;
27 type CanonicalVTable = Self;
28 type OperationsVTable = Self;
29 type ValidityVTable = ValidityVTableFromChild;
30 type VisitorVTable = Self;
31 type ComputeVTable = NotSupported;
32 type EncodeVTable = Self;
33 type SerdeVTable = Self;
34 type PipelineVTable = NotSupported;
35
36 fn id(_encoding: &Self::Encoding) -> EncodingId {
37 EncodingId::new_ref("vortex.alprd")
38 }
39
40 fn encoding(_array: &Self::Array) -> EncodingRef {
41 EncodingRef::new_ref(ALPRDEncoding.as_ref())
42 }
43}
44
45#[derive(Clone, Debug)]
46pub struct ALPRDArray {
47 dtype: DType,
48 left_parts: ArrayRef,
49 left_parts_patches: Option<Patches>,
50 left_parts_dictionary: Buffer<u16>,
51 right_parts: ArrayRef,
52 right_bit_width: u8,
53 stats_set: ArrayStats,
54}
55
/// Unit type standing for the ALP-RD encoding; it is the `Encoding` associated type of
/// `ALPRDVTable`.
#[derive(Debug, Clone)]
pub struct ALPRDEncoding;
58
59impl ALPRDArray {
60 pub fn try_new(
62 dtype: DType,
63 left_parts: ArrayRef,
64 left_parts_dictionary: Buffer<u16>,
65 right_parts: ArrayRef,
66 right_bit_width: u8,
67 left_parts_patches: Option<Patches>,
68 ) -> VortexResult<Self> {
69 if !dtype.is_float() {
70 vortex_bail!("ALPRDArray given invalid DType ({dtype})");
71 }
72
73 let len = left_parts.len();
74 if right_parts.len() != len {
75 vortex_bail!(
76 "left_parts (len {}) and right_parts (len {}) must be of same length",
77 len,
78 right_parts.len()
79 );
80 }
81
82 if !left_parts.dtype().is_unsigned_int() {
83 vortex_bail!("left_parts dtype must be uint");
84 }
85 if dtype.is_nullable() != left_parts.dtype().is_nullable() {
87 vortex_bail!(
88 "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
89 dtype,
90 left_parts.dtype()
91 );
92 }
93
94 if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
96 vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
97 }
98
99 let left_parts_patches = left_parts_patches
100 .map(|patches| {
101 if !patches.values().all_valid() {
102 vortex_bail!("patches must be all valid: {}", patches.values());
103 }
104 patches.cast_values(left_parts.dtype())
106 })
107 .transpose()?;
108
109 Ok(Self {
110 dtype,
111 left_parts,
112 left_parts_dictionary,
113 right_parts,
114 right_bit_width,
115 left_parts_patches,
116 stats_set: Default::default(),
117 })
118 }
119
120 pub(crate) unsafe fn new_unchecked(
123 dtype: DType,
124 left_parts: ArrayRef,
125 left_parts_dictionary: Buffer<u16>,
126 right_parts: ArrayRef,
127 right_bit_width: u8,
128 left_parts_patches: Option<Patches>,
129 ) -> Self {
130 Self {
131 dtype,
132 left_parts,
133 left_parts_patches,
134 left_parts_dictionary,
135 right_parts,
136 right_bit_width,
137 stats_set: Default::default(),
138 }
139 }
140
141 #[inline]
145 pub fn is_f32(&self) -> bool {
146 matches!(&self.dtype, DType::Primitive(PType::F32, _))
147 }
148
149 pub fn left_parts(&self) -> &ArrayRef {
154 &self.left_parts
155 }
156
157 pub fn right_parts(&self) -> &ArrayRef {
159 &self.right_parts
160 }
161
162 #[inline]
163 pub fn right_bit_width(&self) -> u8 {
164 self.right_bit_width
165 }
166
167 pub fn left_parts_patches(&self) -> Option<&Patches> {
169 self.left_parts_patches.as_ref()
170 }
171
172 #[inline]
174 pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
175 &self.left_parts_dictionary
176 }
177
178 pub fn replace_left_parts_patches(&mut self, patches: Option<Patches>) {
179 self.left_parts_patches = patches;
180 }
181}
182
183impl ValidityChild<ALPRDVTable> for ALPRDVTable {
184 fn validity_child(array: &ALPRDArray) -> &dyn Array {
185 array.left_parts()
186 }
187}
188
189impl ArrayVTable<ALPRDVTable> for ALPRDVTable {
190 fn len(array: &ALPRDArray) -> usize {
191 array.left_parts.len()
192 }
193
194 fn dtype(array: &ALPRDArray) -> &DType {
195 &array.dtype
196 }
197
198 fn stats(array: &ALPRDArray) -> StatsSetRef<'_> {
199 array.stats_set.to_ref(array.as_ref())
200 }
201}
202
203impl CanonicalVTable<ALPRDVTable> for ALPRDVTable {
204 fn canonicalize(array: &ALPRDArray) -> Canonical {
205 let left_parts = array.left_parts().to_primitive();
206 let right_parts = array.right_parts().to_primitive();
207
208 let left_parts_dict = array.left_parts_dictionary();
210
211 let decoded_array = if array.is_f32() {
212 PrimitiveArray::new(
213 alp_rd_decode::<f32>(
214 left_parts.into_buffer::<u16>(),
215 left_parts_dict,
216 array.right_bit_width,
217 right_parts.into_buffer_mut::<u32>(),
218 array.left_parts_patches(),
219 ),
220 Validity::copy_from_array(array.as_ref()),
221 )
222 } else {
223 PrimitiveArray::new(
224 alp_rd_decode::<f64>(
225 left_parts.into_buffer::<u16>(),
226 left_parts_dict,
227 array.right_bit_width,
228 right_parts.into_buffer_mut::<u64>(),
229 array.left_parts_patches(),
230 ),
231 Validity::copy_from_array(array.as_ref()),
232 )
233 };
234
235 Canonical::Primitive(decoded_array)
236 }
237}
238
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;

    use crate::{ALPRDFloat, alp_rd};

    /// Encodes a 1024-element float fixture containing three nulls, decodes it again, and
    /// checks that the decoded values equal the input with each null replaced by `T::default()`.
    ///
    /// The encoder is built from a single sample, `seed.powi(-2)`.
    /// NOTE(review): presumably chosen so the fixture values miss the dictionary and are stored
    /// as patches — confirm against `RDEncoder`.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");

        // Positions that are turned into nulls.
        const NULL_POSITIONS: [usize; 3] = [1, 5, 900];
        let reals: Vec<Option<T>> = reals
            .into_iter()
            .enumerate()
            .map(|(position, value)| (!NULL_POSITIONS.contains(&position)).then_some(value))
            .collect();

        let real_array = PrimitiveArray::from_option_iter(reals.iter().cloned());

        let encoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);
        let rd_array = encoder.encode(&real_array);
        let decoded = rd_array.to_primitive();

        // Nulls are expected to come back as the default value of `T`.
        let maybe_null_reals: Vec<T> = reals
            .into_iter()
            .map(Option::unwrap_or_default)
            .collect();
        assert_eq!(decoded.as_slice::<T>(), &maybe_null_reals);
    }
}