1use std::fmt::Debug;
5use std::hash::Hash;
6
7use itertools::Itertools;
8use vortex_array::Array;
9use vortex_array::ArrayEq;
10use vortex_array::ArrayHash;
11use vortex_array::ArrayRef;
12use vortex_array::DeserializeMetadata;
13use vortex_array::ExecutionCtx;
14use vortex_array::IntoArray;
15use vortex_array::Precision;
16use vortex_array::ProstMetadata;
17use vortex_array::SerializeMetadata;
18use vortex_array::arrays::PrimitiveArray;
19use vortex_array::buffer::BufferHandle;
20use vortex_array::dtype::DType;
21use vortex_array::dtype::Nullability;
22use vortex_array::dtype::PType;
23use vortex_array::patches::Patches;
24use vortex_array::patches::PatchesMetadata;
25use vortex_array::serde::ArrayChildren;
26use vortex_array::stats::ArrayStats;
27use vortex_array::stats::StatsSetRef;
28use vortex_array::validity::Validity;
29use vortex_array::vtable;
30use vortex_array::vtable::ArrayId;
31use vortex_array::vtable::VTable;
32use vortex_array::vtable::ValidityChild;
33use vortex_array::vtable::ValidityVTableFromChild;
34use vortex_array::vtable::patches_child;
35use vortex_array::vtable::patches_child_name;
36use vortex_array::vtable::patches_nchildren;
37use vortex_buffer::Buffer;
38use vortex_error::VortexResult;
39use vortex_error::vortex_bail;
40use vortex_error::vortex_ensure;
41use vortex_error::vortex_err;
42use vortex_error::vortex_panic;
43use vortex_mask::Mask;
44use vortex_session::VortexSession;
45
46use crate::alp_rd::kernel::PARENT_KERNELS;
47use crate::alp_rd::rules::RULES;
48use crate::alp_rd_decode;
49
50vtable!(ALPRD);
51
52#[derive(Clone, prost::Message)]
53pub struct ALPRDMetadata {
54 #[prost(uint32, tag = "1")]
55 right_bit_width: u32,
56 #[prost(uint32, tag = "2")]
57 dict_len: u32,
58 #[prost(uint32, repeated, tag = "3")]
59 dict: Vec<u32>,
60 #[prost(enumeration = "PType", tag = "4")]
61 left_parts_ptype: i32,
62 #[prost(message, tag = "5")]
63 patches: Option<PatchesMetadata>,
64}
65
66impl VTable for ALPRDVTable {
67 type Array = ALPRDArray;
68
69 type Metadata = ProstMetadata<ALPRDMetadata>;
70 type OperationsVTable = Self;
71 type ValidityVTable = ValidityVTableFromChild;
72
73 fn id(_array: &Self::Array) -> ArrayId {
74 Self::ID
75 }
76
77 fn len(array: &ALPRDArray) -> usize {
78 array.left_parts.len()
79 }
80
81 fn dtype(array: &ALPRDArray) -> &DType {
82 &array.dtype
83 }
84
85 fn stats(array: &ALPRDArray) -> StatsSetRef<'_> {
86 array.stats_set.to_ref(array.as_ref())
87 }
88
89 fn array_hash<H: std::hash::Hasher>(array: &ALPRDArray, state: &mut H, precision: Precision) {
90 array.dtype.hash(state);
91 array.left_parts.array_hash(state, precision);
92 array.left_parts_dictionary.array_hash(state, precision);
93 array.right_parts.array_hash(state, precision);
94 array.right_bit_width.hash(state);
95 array.left_parts_patches.array_hash(state, precision);
96 }
97
98 fn array_eq(array: &ALPRDArray, other: &ALPRDArray, precision: Precision) -> bool {
99 array.dtype == other.dtype
100 && array.left_parts.array_eq(&other.left_parts, precision)
101 && array
102 .left_parts_dictionary
103 .array_eq(&other.left_parts_dictionary, precision)
104 && array.right_parts.array_eq(&other.right_parts, precision)
105 && array.right_bit_width == other.right_bit_width
106 && array
107 .left_parts_patches
108 .array_eq(&other.left_parts_patches, precision)
109 }
110
111 fn nbuffers(_array: &ALPRDArray) -> usize {
112 0
113 }
114
115 fn buffer(_array: &ALPRDArray, idx: usize) -> BufferHandle {
116 vortex_panic!("ALPRDArray buffer index {idx} out of bounds")
117 }
118
119 fn buffer_name(_array: &ALPRDArray, _idx: usize) -> Option<String> {
120 None
121 }
122
123 fn nchildren(array: &ALPRDArray) -> usize {
124 2 + array.left_parts_patches().map_or(0, patches_nchildren)
125 }
126
127 fn child(array: &ALPRDArray, idx: usize) -> ArrayRef {
128 match idx {
129 0 => array.left_parts().clone(),
130 1 => array.right_parts().clone(),
131 _ => {
132 let patches = array
133 .left_parts_patches()
134 .unwrap_or_else(|| vortex_panic!("ALPRDArray child index {idx} out of bounds"));
135 patches_child(patches, idx - 2)
136 }
137 }
138 }
139
140 fn child_name(array: &ALPRDArray, idx: usize) -> String {
141 match idx {
142 0 => "left_parts".to_string(),
143 1 => "right_parts".to_string(),
144 _ => {
145 if array.left_parts_patches().is_none() {
146 vortex_panic!("ALPRDArray child_name index {idx} out of bounds");
147 }
148 patches_child_name(idx - 2).to_string()
149 }
150 }
151 }
152
153 fn metadata(array: &ALPRDArray) -> VortexResult<Self::Metadata> {
154 let dict = array
155 .left_parts_dictionary()
156 .iter()
157 .map(|&i| i as u32)
158 .collect::<Vec<_>>();
159
160 Ok(ProstMetadata(ALPRDMetadata {
161 right_bit_width: array.right_bit_width() as u32,
162 dict_len: array.left_parts_dictionary().len() as u32,
163 dict,
164 left_parts_ptype: array.left_parts.dtype().as_ptype() as i32,
165 patches: array
166 .left_parts_patches()
167 .map(|p| p.to_metadata(array.len(), array.left_parts().dtype()))
168 .transpose()?,
169 }))
170 }
171
172 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
173 Ok(Some(metadata.serialize()))
174 }
175
176 fn deserialize(
177 bytes: &[u8],
178 _dtype: &DType,
179 _len: usize,
180 _buffers: &[BufferHandle],
181 _session: &VortexSession,
182 ) -> VortexResult<Self::Metadata> {
183 Ok(ProstMetadata(
184 <ProstMetadata<ALPRDMetadata> as DeserializeMetadata>::deserialize(bytes)?,
185 ))
186 }
187
188 fn build(
189 dtype: &DType,
190 len: usize,
191 metadata: &Self::Metadata,
192 _buffers: &[BufferHandle],
193 children: &dyn ArrayChildren,
194 ) -> VortexResult<ALPRDArray> {
195 if children.len() < 2 {
196 vortex_bail!(
197 "Expected at least 2 children for ALPRD encoding, found {}",
198 children.len()
199 );
200 }
201
202 let left_parts_dtype = DType::Primitive(metadata.0.left_parts_ptype(), dtype.nullability());
203 let left_parts = children.get(0, &left_parts_dtype, len)?;
204 let left_parts_dictionary: Buffer<u16> = metadata.0.dict.as_slice()
205 [0..metadata.0.dict_len as usize]
206 .iter()
207 .map(|&i| {
208 u16::try_from(i)
209 .map_err(|_| vortex_err!("left_parts_dictionary code {i} does not fit in u16"))
210 })
211 .try_collect()?;
212
213 let right_parts_dtype = match &dtype {
214 DType::Primitive(PType::F32, _) => {
215 DType::Primitive(PType::U32, Nullability::NonNullable)
216 }
217 DType::Primitive(PType::F64, _) => {
218 DType::Primitive(PType::U64, Nullability::NonNullable)
219 }
220 _ => vortex_bail!("Expected f32 or f64 dtype, got {:?}", dtype),
221 };
222 let right_parts = children.get(1, &right_parts_dtype, len)?;
223
224 let left_parts_patches = metadata
225 .0
226 .patches
227 .map(|p| {
228 let indices = children.get(2, &p.indices_dtype()?, p.len()?)?;
229 let values = children.get(3, &left_parts_dtype, p.len()?)?;
230
231 Patches::new(
232 len,
233 p.offset()?,
234 indices,
235 values,
236 None,
238 )
239 })
240 .transpose()?;
241
242 ALPRDArray::try_new(
243 dtype.clone(),
244 left_parts,
245 left_parts_dictionary,
246 right_parts,
247 u8::try_from(metadata.0.right_bit_width).map_err(|_| {
248 vortex_err!(
249 "right_bit_width {} out of u8 range",
250 metadata.0.right_bit_width
251 )
252 })?,
253 left_parts_patches,
254 )
255 }
256
257 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
258 let patches_info = array
260 .left_parts_patches
261 .as_ref()
262 .map(|p| (p.array_len(), p.offset()));
263
264 let expected_children = if patches_info.is_some() { 4 } else { 2 };
265
266 vortex_ensure!(
267 children.len() == expected_children,
268 "ALPRDArray expects {} children, got {}",
269 expected_children,
270 children.len()
271 );
272
273 let mut children_iter = children.into_iter();
274 array.left_parts = children_iter
275 .next()
276 .ok_or_else(|| vortex_err!("Expected left_parts child"))?;
277 array.right_parts = children_iter
278 .next()
279 .ok_or_else(|| vortex_err!("Expected right_parts child"))?;
280
281 if let Some((array_len, offset)) = patches_info {
282 let indices = children_iter
283 .next()
284 .ok_or_else(|| vortex_err!("Expected patch indices child"))?;
285 let values = children_iter
286 .next()
287 .ok_or_else(|| vortex_err!("Expected patch values child"))?;
288
289 array.left_parts_patches = Some(Patches::new(
290 array_len, offset, indices, values,
291 None, )?);
293 }
294
295 Ok(())
296 }
297
298 fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
299 let left_parts = array.left_parts().clone().execute::<PrimitiveArray>(ctx)?;
300 let right_parts = array.right_parts().clone().execute::<PrimitiveArray>(ctx)?;
301
302 let left_parts_dict = array.left_parts_dictionary();
304
305 let validity = array
306 .left_parts()
307 .validity()?
308 .to_array(array.len())
309 .execute::<Mask>(ctx)?;
310
311 let decoded_array = if array.is_f32() {
312 PrimitiveArray::new(
313 alp_rd_decode::<f32>(
314 left_parts.into_buffer::<u16>(),
315 left_parts_dict,
316 array.right_bit_width,
317 right_parts.into_buffer_mut::<u32>(),
318 array.left_parts_patches(),
319 ctx,
320 )?,
321 Validity::from_mask(validity, array.dtype().nullability()),
322 )
323 } else {
324 PrimitiveArray::new(
325 alp_rd_decode::<f64>(
326 left_parts.into_buffer::<u16>(),
327 left_parts_dict,
328 array.right_bit_width,
329 right_parts.into_buffer_mut::<u64>(),
330 array.left_parts_patches(),
331 ctx,
332 )?,
333 Validity::from_mask(validity, array.dtype().nullability()),
334 )
335 };
336
337 Ok(decoded_array.into_array())
338 }
339
340 fn reduce_parent(
341 array: &Self::Array,
342 parent: &ArrayRef,
343 child_idx: usize,
344 ) -> VortexResult<Option<ArrayRef>> {
345 RULES.evaluate(array, parent, child_idx)
346 }
347
348 fn execute_parent(
349 array: &Self::Array,
350 parent: &ArrayRef,
351 child_idx: usize,
352 ctx: &mut ExecutionCtx,
353 ) -> VortexResult<Option<ArrayRef>> {
354 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
355 }
356}
357
358#[derive(Clone, Debug)]
359pub struct ALPRDArray {
360 dtype: DType,
361 left_parts: ArrayRef,
362 left_parts_patches: Option<Patches>,
363 left_parts_dictionary: Buffer<u16>,
364 right_parts: ArrayRef,
365 right_bit_width: u8,
366 stats_set: ArrayStats,
367}
368
369#[derive(Debug)]
370pub struct ALPRDVTable;
371
372impl ALPRDVTable {
373 pub const ID: ArrayId = ArrayId::new_ref("vortex.alprd");
374}
375
376impl ALPRDArray {
377 pub fn try_new(
379 dtype: DType,
380 left_parts: ArrayRef,
381 left_parts_dictionary: Buffer<u16>,
382 right_parts: ArrayRef,
383 right_bit_width: u8,
384 left_parts_patches: Option<Patches>,
385 ) -> VortexResult<Self> {
386 if !dtype.is_float() {
387 vortex_bail!("ALPRDArray given invalid DType ({dtype})");
388 }
389
390 let len = left_parts.len();
391 if right_parts.len() != len {
392 vortex_bail!(
393 "left_parts (len {}) and right_parts (len {}) must be of same length",
394 len,
395 right_parts.len()
396 );
397 }
398
399 if !left_parts.dtype().is_unsigned_int() {
400 vortex_bail!("left_parts dtype must be uint");
401 }
402 if dtype.is_nullable() != left_parts.dtype().is_nullable() {
404 vortex_bail!(
405 "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
406 dtype,
407 left_parts.dtype()
408 );
409 }
410
411 if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
413 vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
414 }
415
416 let left_parts_patches = left_parts_patches
417 .map(|patches| {
418 if !patches.values().all_valid()? {
419 vortex_bail!("patches must be all valid: {}", patches.values());
420 }
421 let mut patches = patches.cast_values(left_parts.dtype())?;
424 *patches.values_mut() = patches.values().to_canonical()?.into_array();
427 Ok(patches)
428 })
429 .transpose()?;
430
431 Ok(Self {
432 dtype,
433 left_parts,
434 left_parts_dictionary,
435 right_parts,
436 right_bit_width,
437 left_parts_patches,
438 stats_set: Default::default(),
439 })
440 }
441
442 pub(crate) unsafe fn new_unchecked(
445 dtype: DType,
446 left_parts: ArrayRef,
447 left_parts_dictionary: Buffer<u16>,
448 right_parts: ArrayRef,
449 right_bit_width: u8,
450 left_parts_patches: Option<Patches>,
451 ) -> Self {
452 Self {
453 dtype,
454 left_parts,
455 left_parts_patches,
456 left_parts_dictionary,
457 right_parts,
458 right_bit_width,
459 stats_set: Default::default(),
460 }
461 }
462
463 #[inline]
467 pub fn is_f32(&self) -> bool {
468 matches!(&self.dtype, DType::Primitive(PType::F32, _))
469 }
470
471 pub fn left_parts(&self) -> &ArrayRef {
476 &self.left_parts
477 }
478
479 pub fn right_parts(&self) -> &ArrayRef {
481 &self.right_parts
482 }
483
484 #[inline]
485 pub fn right_bit_width(&self) -> u8 {
486 self.right_bit_width
487 }
488
489 pub fn left_parts_patches(&self) -> Option<&Patches> {
491 self.left_parts_patches.as_ref()
492 }
493
494 #[inline]
496 pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
497 &self.left_parts_dictionary
498 }
499
500 pub fn replace_left_parts_patches(&mut self, patches: Option<Patches>) {
501 self.left_parts_patches = patches;
502 }
503}
504
505impl ValidityChild<ALPRDVTable> for ALPRDVTable {
506 fn validity_child(array: &ALPRDArray) -> &ArrayRef {
507 array.left_parts()
508 }
509}
510
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ProstMetadata;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::PType;
    use vortex_array::patches::PatchesMetadata;
    use vortex_array::test_harness::check_metadata;

    use super::ALPRDMetadata;
    use crate::ALPRDFloat;
    use crate::alp_rd;

    /// Round-trips a nullable float column through ALP-RD and checks that decoding
    /// reproduces every value and every null.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");

        // Null out a handful of positions in the fixture.
        let null_positions = [1, 5, 900];
        let expected: Vec<Option<T>> = reals
            .into_iter()
            .enumerate()
            .map(|(idx, value)| (!null_positions.contains(&idx)).then_some(value))
            .collect();

        let input = PrimitiveArray::from_option_iter(expected.iter().cloned());

        // The encoder is trained on a sample derived from `seed`, not on the input values.
        let rd_encoder: alp_rd::RDEncoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);
        let encoded = rd_encoder.encode(&input);

        let round_tripped = encoded.to_primitive();
        assert_arrays_eq!(round_tripped, PrimitiveArray::from_option_iter(expected));
    }

    /// Pins the serialized metadata layout against the stored golden file.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_alprd_metadata() {
        let metadata = ALPRDMetadata {
            right_bit_width: u32::MAX,
            patches: Some(PatchesMetadata::new(
                usize::MAX,
                usize::MAX,
                PType::U64,
                None,
                None,
                None,
            )),
            dict: Vec::new(),
            left_parts_ptype: PType::U64 as i32,
            dict_len: 8,
        };
        check_metadata("alprd.metadata", ProstMetadata(metadata));
    }
}