1use std::fmt::Debug;
5use std::hash::Hash;
6
7use itertools::Itertools;
8use vortex_array::Array;
9use vortex_array::ArrayBufferVisitor;
10use vortex_array::ArrayChildVisitor;
11use vortex_array::ArrayEq;
12use vortex_array::ArrayHash;
13use vortex_array::ArrayRef;
14use vortex_array::DeserializeMetadata;
15use vortex_array::ExecutionCtx;
16use vortex_array::IntoArray;
17use vortex_array::Precision;
18use vortex_array::ProstMetadata;
19use vortex_array::SerializeMetadata;
20use vortex_array::arrays::PrimitiveArray;
21use vortex_array::buffer::BufferHandle;
22use vortex_array::dtype::DType;
23use vortex_array::dtype::Nullability;
24use vortex_array::dtype::PType;
25use vortex_array::patches::Patches;
26use vortex_array::patches::PatchesMetadata;
27use vortex_array::serde::ArrayChildren;
28use vortex_array::stats::ArrayStats;
29use vortex_array::stats::StatsSetRef;
30use vortex_array::validity::Validity;
31use vortex_array::vtable;
32use vortex_array::vtable::ArrayId;
33use vortex_array::vtable::BaseArrayVTable;
34use vortex_array::vtable::VTable;
35use vortex_array::vtable::ValidityChild;
36use vortex_array::vtable::ValidityVTableFromChild;
37use vortex_array::vtable::VisitorVTable;
38use vortex_buffer::Buffer;
39use vortex_error::VortexResult;
40use vortex_error::vortex_bail;
41use vortex_error::vortex_ensure;
42use vortex_error::vortex_err;
43use vortex_mask::Mask;
44use vortex_session::VortexSession;
45
46use crate::alp_rd::kernel::PARENT_KERNELS;
47use crate::alp_rd::rules::RULES;
48use crate::alp_rd_decode;
49
// Expands the `vortex_array::vtable!` boilerplate for the ALPRD encoding.
// NOTE(review): the exact items generated (array/vtable glue between `ALPRDArray` and
// `ALPRDVTable`) are defined by the macro, not this file — confirm against its definition.
vtable!(ALPRD);
51
/// Serialized (protobuf) metadata for an ALP-RD array.
///
/// Holds everything needed to rebuild the array that is not stored in its child arrays.
#[derive(Clone, prost::Message)]
pub struct ALPRDMetadata {
    /// Right bit width of the array; narrowed back to `u8` (checked) when the array is built.
    #[prost(uint32, tag = "1")]
    right_bit_width: u32,
    /// Number of entries of `dict` that belong to the left-parts dictionary.
    #[prost(uint32, tag = "2")]
    dict_len: u32,
    /// Left-parts dictionary, widened from `u16` to `u32` for protobuf; each entry must fit in
    /// `u16` when read back.
    #[prost(uint32, repeated, tag = "3")]
    dict: Vec<u32>,
    /// `PType` of the `left_parts` child, stored as its protobuf enumeration value.
    #[prost(enumeration = "PType", tag = "4")]
    left_parts_ptype: i32,
    /// Metadata of the optional left-parts patches (absent when there are none).
    #[prost(message, tag = "5")]
    patches: Option<PatchesMetadata>,
}
65
66impl VTable for ALPRDVTable {
67 type Array = ALPRDArray;
68
69 type Metadata = ProstMetadata<ALPRDMetadata>;
70
71 type ArrayVTable = Self;
72 type OperationsVTable = Self;
73 type ValidityVTable = ValidityVTableFromChild;
74 type VisitorVTable = Self;
75
76 fn id(_array: &Self::Array) -> ArrayId {
77 Self::ID
78 }
79
80 fn metadata(array: &ALPRDArray) -> VortexResult<Self::Metadata> {
81 let dict = array
82 .left_parts_dictionary()
83 .iter()
84 .map(|&i| i as u32)
85 .collect::<Vec<_>>();
86
87 Ok(ProstMetadata(ALPRDMetadata {
88 right_bit_width: array.right_bit_width() as u32,
89 dict_len: array.left_parts_dictionary().len() as u32,
90 dict,
91 left_parts_ptype: array.left_parts.dtype().as_ptype() as i32,
92 patches: array
93 .left_parts_patches()
94 .map(|p| p.to_metadata(array.len(), array.left_parts().dtype()))
95 .transpose()?,
96 }))
97 }
98
99 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
100 Ok(Some(metadata.serialize()))
101 }
102
103 fn deserialize(
104 bytes: &[u8],
105 _dtype: &DType,
106 _len: usize,
107 _buffers: &[BufferHandle],
108 _session: &VortexSession,
109 ) -> VortexResult<Self::Metadata> {
110 Ok(ProstMetadata(
111 <ProstMetadata<ALPRDMetadata> as DeserializeMetadata>::deserialize(bytes)?,
112 ))
113 }
114
115 fn build(
116 dtype: &DType,
117 len: usize,
118 metadata: &Self::Metadata,
119 _buffers: &[BufferHandle],
120 children: &dyn ArrayChildren,
121 ) -> VortexResult<ALPRDArray> {
122 if children.len() < 2 {
123 vortex_bail!(
124 "Expected at least 2 children for ALPRD encoding, found {}",
125 children.len()
126 );
127 }
128
129 let left_parts_dtype = DType::Primitive(metadata.0.left_parts_ptype(), dtype.nullability());
130 let left_parts = children.get(0, &left_parts_dtype, len)?;
131 let left_parts_dictionary: Buffer<u16> = metadata.0.dict.as_slice()
132 [0..metadata.0.dict_len as usize]
133 .iter()
134 .map(|&i| {
135 u16::try_from(i)
136 .map_err(|_| vortex_err!("left_parts_dictionary code {i} does not fit in u16"))
137 })
138 .try_collect()?;
139
140 let right_parts_dtype = match &dtype {
141 DType::Primitive(PType::F32, _) => {
142 DType::Primitive(PType::U32, Nullability::NonNullable)
143 }
144 DType::Primitive(PType::F64, _) => {
145 DType::Primitive(PType::U64, Nullability::NonNullable)
146 }
147 _ => vortex_bail!("Expected f32 or f64 dtype, got {:?}", dtype),
148 };
149 let right_parts = children.get(1, &right_parts_dtype, len)?;
150
151 let left_parts_patches = metadata
152 .0
153 .patches
154 .map(|p| {
155 let indices = children.get(2, &p.indices_dtype()?, p.len()?)?;
156 let values = children.get(3, &left_parts_dtype, p.len()?)?;
157
158 Patches::new(
159 len,
160 p.offset()?,
161 indices,
162 values,
163 None,
165 )
166 })
167 .transpose()?;
168
169 ALPRDArray::try_new(
170 dtype.clone(),
171 left_parts,
172 left_parts_dictionary,
173 right_parts,
174 u8::try_from(metadata.0.right_bit_width).map_err(|_| {
175 vortex_err!(
176 "right_bit_width {} out of u8 range",
177 metadata.0.right_bit_width
178 )
179 })?,
180 left_parts_patches,
181 )
182 }
183
184 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
185 let patches_info = array
187 .left_parts_patches
188 .as_ref()
189 .map(|p| (p.array_len(), p.offset()));
190
191 let expected_children = if patches_info.is_some() { 4 } else { 2 };
192
193 vortex_ensure!(
194 children.len() == expected_children,
195 "ALPRDArray expects {} children, got {}",
196 expected_children,
197 children.len()
198 );
199
200 let mut children_iter = children.into_iter();
201 array.left_parts = children_iter
202 .next()
203 .ok_or_else(|| vortex_err!("Expected left_parts child"))?;
204 array.right_parts = children_iter
205 .next()
206 .ok_or_else(|| vortex_err!("Expected right_parts child"))?;
207
208 if let Some((array_len, offset)) = patches_info {
209 let indices = children_iter
210 .next()
211 .ok_or_else(|| vortex_err!("Expected patch indices child"))?;
212 let values = children_iter
213 .next()
214 .ok_or_else(|| vortex_err!("Expected patch values child"))?;
215
216 array.left_parts_patches = Some(Patches::new(
217 array_len, offset, indices, values,
218 None, )?);
220 }
221
222 Ok(())
223 }
224
225 fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
226 let left_parts = array.left_parts().clone().execute::<PrimitiveArray>(ctx)?;
227 let right_parts = array.right_parts().clone().execute::<PrimitiveArray>(ctx)?;
228
229 let left_parts_dict = array.left_parts_dictionary();
231
232 let validity = array
233 .left_parts()
234 .validity()?
235 .to_array(array.len())
236 .execute::<Mask>(ctx)?;
237
238 let decoded_array = if array.is_f32() {
239 PrimitiveArray::new(
240 alp_rd_decode::<f32>(
241 left_parts.into_buffer::<u16>(),
242 left_parts_dict,
243 array.right_bit_width,
244 right_parts.into_buffer_mut::<u32>(),
245 array.left_parts_patches(),
246 ctx,
247 )?,
248 Validity::from_mask(validity, array.dtype().nullability()),
249 )
250 } else {
251 PrimitiveArray::new(
252 alp_rd_decode::<f64>(
253 left_parts.into_buffer::<u16>(),
254 left_parts_dict,
255 array.right_bit_width,
256 right_parts.into_buffer_mut::<u64>(),
257 array.left_parts_patches(),
258 ctx,
259 )?,
260 Validity::from_mask(validity, array.dtype().nullability()),
261 )
262 };
263
264 Ok(decoded_array.into_array())
265 }
266
267 fn reduce_parent(
268 array: &Self::Array,
269 parent: &ArrayRef,
270 child_idx: usize,
271 ) -> VortexResult<Option<ArrayRef>> {
272 RULES.evaluate(array, parent, child_idx)
273 }
274
275 fn execute_parent(
276 array: &Self::Array,
277 parent: &ArrayRef,
278 child_idx: usize,
279 ctx: &mut ExecutionCtx,
280 ) -> VortexResult<Option<ArrayRef>> {
281 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
282 }
283}
284
/// An ALP-RD encoded floating-point array.
///
/// Each value is reconstructed from a left part (looked up via `left_parts_dictionary`, with
/// optional patches) and a right part.
/// NOTE(review): the exact bit split is implemented by `alp_rd_decode` — confirm there.
#[derive(Clone, Debug)]
pub struct ALPRDArray {
    // Logical dtype; `f32` or `f64` of either nullability.
    dtype: DType,
    // Unsigned-int left parts; also the source of this array's validity and length.
    left_parts: ArrayRef,
    // Optional patches over `left_parts`, with values cast to `left_parts`' dtype.
    left_parts_patches: Option<Patches>,
    // Dictionary consumed with `left_parts` during decoding; serialized in metadata, not as a
    // buffer.
    left_parts_dictionary: Buffer<u16>,
    // Non-nullable unsigned-int right parts (decoded as u32 for f32, u64 for f64).
    right_parts: ArrayRef,
    // Bit width passed through to `alp_rd_decode`.
    right_bit_width: u8,
    // Cached statistics for this array.
    stats_set: ArrayStats,
}
295
/// Vtable marker type for the ALP-RD encoding.
#[derive(Debug)]
pub struct ALPRDVTable;
298
impl ALPRDVTable {
    /// Stable encoding identifier returned by `VTable::id` for every ALP-RD array.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.alprd");
}
302
303impl ALPRDArray {
304 pub fn try_new(
306 dtype: DType,
307 left_parts: ArrayRef,
308 left_parts_dictionary: Buffer<u16>,
309 right_parts: ArrayRef,
310 right_bit_width: u8,
311 left_parts_patches: Option<Patches>,
312 ) -> VortexResult<Self> {
313 if !dtype.is_float() {
314 vortex_bail!("ALPRDArray given invalid DType ({dtype})");
315 }
316
317 let len = left_parts.len();
318 if right_parts.len() != len {
319 vortex_bail!(
320 "left_parts (len {}) and right_parts (len {}) must be of same length",
321 len,
322 right_parts.len()
323 );
324 }
325
326 if !left_parts.dtype().is_unsigned_int() {
327 vortex_bail!("left_parts dtype must be uint");
328 }
329 if dtype.is_nullable() != left_parts.dtype().is_nullable() {
331 vortex_bail!(
332 "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
333 dtype,
334 left_parts.dtype()
335 );
336 }
337
338 if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
340 vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
341 }
342
343 let left_parts_patches = left_parts_patches
344 .map(|patches| {
345 if !patches.values().all_valid()? {
346 vortex_bail!("patches must be all valid: {}", patches.values());
347 }
348 let mut patches = patches.cast_values(left_parts.dtype())?;
351 *patches.values_mut() = patches.values().to_canonical()?.into_array();
354 Ok(patches)
355 })
356 .transpose()?;
357
358 Ok(Self {
359 dtype,
360 left_parts,
361 left_parts_dictionary,
362 right_parts,
363 right_bit_width,
364 left_parts_patches,
365 stats_set: Default::default(),
366 })
367 }
368
369 pub(crate) unsafe fn new_unchecked(
372 dtype: DType,
373 left_parts: ArrayRef,
374 left_parts_dictionary: Buffer<u16>,
375 right_parts: ArrayRef,
376 right_bit_width: u8,
377 left_parts_patches: Option<Patches>,
378 ) -> Self {
379 Self {
380 dtype,
381 left_parts,
382 left_parts_patches,
383 left_parts_dictionary,
384 right_parts,
385 right_bit_width,
386 stats_set: Default::default(),
387 }
388 }
389
390 #[inline]
394 pub fn is_f32(&self) -> bool {
395 matches!(&self.dtype, DType::Primitive(PType::F32, _))
396 }
397
398 pub fn left_parts(&self) -> &ArrayRef {
403 &self.left_parts
404 }
405
406 pub fn right_parts(&self) -> &ArrayRef {
408 &self.right_parts
409 }
410
411 #[inline]
412 pub fn right_bit_width(&self) -> u8 {
413 self.right_bit_width
414 }
415
416 pub fn left_parts_patches(&self) -> Option<&Patches> {
418 self.left_parts_patches.as_ref()
419 }
420
421 #[inline]
423 pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
424 &self.left_parts_dictionary
425 }
426
427 pub fn replace_left_parts_patches(&mut self, patches: Option<Patches>) {
428 self.left_parts_patches = patches;
429 }
430}
431
432impl ValidityChild<ALPRDVTable> for ALPRDVTable {
433 fn validity_child(array: &ALPRDArray) -> &ArrayRef {
434 array.left_parts()
435 }
436}
437
438impl BaseArrayVTable<ALPRDVTable> for ALPRDVTable {
439 fn len(array: &ALPRDArray) -> usize {
440 array.left_parts.len()
441 }
442
443 fn dtype(array: &ALPRDArray) -> &DType {
444 &array.dtype
445 }
446
447 fn stats(array: &ALPRDArray) -> StatsSetRef<'_> {
448 array.stats_set.to_ref(array.as_ref())
449 }
450
451 fn array_hash<H: std::hash::Hasher>(array: &ALPRDArray, state: &mut H, precision: Precision) {
452 array.dtype.hash(state);
453 array.left_parts.array_hash(state, precision);
454 array.left_parts_dictionary.array_hash(state, precision);
455 array.right_parts.array_hash(state, precision);
456 array.right_bit_width.hash(state);
457 array.left_parts_patches.array_hash(state, precision);
458 }
459
460 fn array_eq(array: &ALPRDArray, other: &ALPRDArray, precision: Precision) -> bool {
461 array.dtype == other.dtype
462 && array.left_parts.array_eq(&other.left_parts, precision)
463 && array
464 .left_parts_dictionary
465 .array_eq(&other.left_parts_dictionary, precision)
466 && array.right_parts.array_eq(&other.right_parts, precision)
467 && array.right_bit_width == other.right_bit_width
468 && array
469 .left_parts_patches
470 .array_eq(&other.left_parts_patches, precision)
471 }
472}
473
474impl VisitorVTable<ALPRDVTable> for ALPRDVTable {
475 fn visit_buffers(_array: &ALPRDArray, _visitor: &mut dyn ArrayBufferVisitor) {}
476
477 fn nbuffers(_array: &ALPRDArray) -> usize {
478 0
479 }
480
481 fn visit_children(array: &ALPRDArray, visitor: &mut dyn ArrayChildVisitor) {
482 visitor.visit_child("left_parts", array.left_parts());
483 visitor.visit_child("right_parts", array.right_parts());
484 if let Some(patches) = array.left_parts_patches() {
485 visitor.visit_patches(patches);
486 }
487 }
488
489 fn nchildren(array: &ALPRDArray) -> usize {
490 2 + array
492 .left_parts_patches()
493 .map_or(0, |p| 2 + p.chunk_offsets().is_some() as usize)
494 }
495}
496
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ProstMetadata;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::PType;
    use vortex_array::patches::PatchesMetadata;
    use vortex_array::test_harness::check_metadata;

    use super::ALPRDMetadata;
    use crate::ALPRDFloat;
    use crate::alp_rd;

    /// Positions set to null before encoding.
    const NULL_POSITIONS: [usize; 3] = [1, 5, 900];

    /// Encodes a nullable float array with ALP-RD and checks that decoding reproduces it
    /// exactly, nulls included.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");
        let expected: Vec<Option<T>> = reals
            .into_iter()
            .enumerate()
            .map(|(idx, value)| (!NULL_POSITIONS.contains(&idx)).then_some(value))
            .collect();

        let input = PrimitiveArray::from_option_iter(expected.iter().cloned());
        let encoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);
        let round_tripped = encoder.encode(&input).to_primitive();

        assert_arrays_eq!(round_tripped, PrimitiveArray::from_option_iter(expected));
    }

    /// Pins the serialized metadata layout using extreme field values.
    /// NOTE(review): `check_metadata` presumably compares against a stored `alprd.metadata`
    /// snapshot — confirm in the test harness.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_alprd_metadata() {
        let patches = PatchesMetadata::new(usize::MAX, usize::MAX, PType::U64, None, None, None);
        let metadata = ALPRDMetadata {
            right_bit_width: u32::MAX,
            dict_len: 8,
            dict: Vec::new(),
            left_parts_ptype: PType::U64 as i32,
            patches: Some(patches),
        };
        check_metadata("alprd.metadata", ProstMetadata(metadata));
    }
}