1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9
10use itertools::Itertools;
11use prost::Message;
12use vortex_array::Array;
13use vortex_array::ArrayEq;
14use vortex_array::ArrayHash;
15use vortex_array::ArrayId;
16use vortex_array::ArrayParts;
17use vortex_array::ArrayRef;
18use vortex_array::ArrayView;
19use vortex_array::ExecutionCtx;
20use vortex_array::ExecutionResult;
21use vortex_array::IntoArray;
22use vortex_array::Precision;
23use vortex_array::TypedArrayRef;
24use vortex_array::arrays::Primitive;
25use vortex_array::arrays::PrimitiveArray;
26use vortex_array::buffer::BufferHandle;
27use vortex_array::dtype::DType;
28use vortex_array::dtype::Nullability;
29use vortex_array::dtype::PType;
30use vortex_array::patches::Patches;
31use vortex_array::patches::PatchesMetadata;
32use vortex_array::require_child;
33use vortex_array::require_patches;
34use vortex_array::serde::ArrayChildren;
35use vortex_array::validity::Validity;
36use vortex_array::vtable::VTable;
37use vortex_array::vtable::ValidityChild;
38use vortex_array::vtable::ValidityVTableFromChild;
39use vortex_buffer::Buffer;
40use vortex_error::VortexExpect;
41use vortex_error::VortexResult;
42use vortex_error::vortex_bail;
43use vortex_error::vortex_ensure;
44use vortex_error::vortex_err;
45use vortex_error::vortex_panic;
46use vortex_session::VortexSession;
47
48use crate::alp_rd::kernel::PARENT_KERNELS;
49use crate::alp_rd::rules::RULES;
50use crate::alp_rd_decode;
51
/// An [`Array`] whose encoding is driven by the [`ALPRD`] vtable.
pub type ALPRDArray = Array<ALPRD>;
54
/// Protobuf metadata written by the `ALPRD` vtable's `serialize` and parsed by its
/// `deserialize`.
#[derive(Clone, prost::Message)]
pub struct ALPRDMetadata {
    /// Bit width of the right parts. Held as `u8` in memory; `deserialize` rejects values that
    /// do not fit in a `u8`.
    #[prost(uint32, tag = "1")]
    right_bit_width: u32,
    /// Number of leading entries of `dict` that make up the left-parts dictionary.
    #[prost(uint32, tag = "2")]
    dict_len: u32,
    /// Left-parts dictionary entries, widened from `u16` for the wire format.
    #[prost(uint32, repeated, tag = "3")]
    dict: Vec<u32>,
    /// `PType` of the `left_parts` child, stored as its protobuf enum value.
    #[prost(enumeration = "PType", tag = "4")]
    left_parts_ptype: i32,
    /// Metadata of the optional left-parts patches.
    #[prost(message, tag = "5")]
    patches: Option<PatchesMetadata>,
}
68
69impl ArrayHash for ALPRDData {
70 fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
71 self.left_parts_dictionary.array_hash(state, precision);
72 self.right_bit_width.hash(state);
73 self.patch_offset.hash(state);
74 self.patch_offset_within_chunk.hash(state);
75 }
76}
77
78impl ArrayEq for ALPRDData {
79 fn array_eq(&self, other: &Self, precision: Precision) -> bool {
80 self.left_parts_dictionary
81 .array_eq(&other.left_parts_dictionary, precision)
82 && self.right_bit_width == other.right_bit_width
83 && self.patch_offset == other.patch_offset
84 && self.patch_offset_within_chunk == other.patch_offset_within_chunk
85 }
86}
87
88impl VTable for ALPRD {
89 type ArrayData = ALPRDData;
90
91 type OperationsVTable = Self;
92 type ValidityVTable = ValidityVTableFromChild;
93
94 fn id(&self) -> ArrayId {
95 Self::ID
96 }
97
98 fn validate(
99 &self,
100 data: &ALPRDData,
101 dtype: &DType,
102 len: usize,
103 slots: &[Option<ArrayRef>],
104 ) -> VortexResult<()> {
105 validate_parts(
106 dtype,
107 len,
108 left_parts_from_slots(slots),
109 right_parts_from_slots(slots),
110 patches_from_slots(
111 slots,
112 data.patch_offset,
113 data.patch_offset_within_chunk,
114 len,
115 )
116 .as_ref(),
117 )
118 }
119
120 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
121 0
122 }
123
124 fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
125 vortex_panic!("ALPRDArray buffer index {idx} out of bounds")
126 }
127
128 fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
129 None
130 }
131
132 fn serialize(
133 array: ArrayView<'_, Self>,
134 _session: &VortexSession,
135 ) -> VortexResult<Option<Vec<u8>>> {
136 let dict = array
137 .left_parts_dictionary()
138 .iter()
139 .map(|&i| i as u32)
140 .collect::<Vec<_>>();
141
142 Ok(Some(
143 ALPRDMetadata {
144 right_bit_width: array.right_bit_width() as u32,
145 dict_len: array.left_parts_dictionary().len() as u32,
146 dict,
147 left_parts_ptype: array.left_parts().dtype().as_ptype() as i32,
148 patches: array
149 .left_parts_patches()
150 .map(|p| p.to_metadata(array.len(), p.dtype()))
151 .transpose()?,
152 }
153 .encode_to_vec(),
154 ))
155 }
156
157 fn deserialize(
158 &self,
159 dtype: &DType,
160 len: usize,
161 metadata: &[u8],
162 _buffers: &[BufferHandle],
163 children: &dyn ArrayChildren,
164 _session: &VortexSession,
165 ) -> VortexResult<ArrayParts<Self>> {
166 let metadata = ALPRDMetadata::decode(metadata)?;
167 if children.len() < 2 {
168 vortex_bail!(
169 "Expected at least 2 children for ALPRD encoding, found {}",
170 children.len()
171 );
172 }
173
174 let left_parts_dtype = DType::Primitive(metadata.left_parts_ptype(), dtype.nullability());
175 let left_parts = children.get(0, &left_parts_dtype, len)?;
176 let left_parts_dictionary: Buffer<u16> = metadata.dict.as_slice()
177 [0..metadata.dict_len as usize]
178 .iter()
179 .map(|&i| {
180 u16::try_from(i)
181 .map_err(|_| vortex_err!("left_parts_dictionary code {i} does not fit in u16"))
182 })
183 .try_collect()?;
184
185 let right_parts_dtype = match &dtype {
186 DType::Primitive(PType::F32, _) => {
187 DType::Primitive(PType::U32, Nullability::NonNullable)
188 }
189 DType::Primitive(PType::F64, _) => {
190 DType::Primitive(PType::U64, Nullability::NonNullable)
191 }
192 _ => vortex_bail!("Expected f32 or f64 dtype, got {:?}", dtype),
193 };
194 let right_parts = children.get(1, &right_parts_dtype, len)?;
195
196 let left_parts_patches = metadata
197 .patches
198 .map(|p| {
199 let indices = children.get(2, &p.indices_dtype()?, p.len()?)?;
200 let values = children.get(3, &left_parts_dtype.as_nonnullable(), p.len()?)?;
201
202 Patches::new(
203 len,
204 p.offset()?,
205 indices,
206 values,
207 None,
209 )
210 })
211 .transpose()?;
212 let left_parts_patches = ALPRDData::canonicalize_patches(&left_parts, left_parts_patches)?;
213 let slots = ALPRDData::make_slots(&left_parts, &right_parts, &left_parts_patches);
214 let data = ALPRDData::new(
215 left_parts_dictionary,
216 u8::try_from(metadata.right_bit_width).map_err(|_| {
217 vortex_err!(
218 "right_bit_width {} out of u8 range",
219 metadata.right_bit_width
220 )
221 })?,
222 left_parts_patches,
223 );
224 Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
225 }
226
227 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
228 SLOT_NAMES[idx].to_string()
229 }
230
231 fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
232 let array = require_child!(array, array.left_parts(), 0 => Primitive);
233 let array = require_child!(array, array.right_parts(), 1 => Primitive);
234 require_patches!(
235 array,
236 LP_PATCH_INDICES_SLOT,
237 LP_PATCH_VALUES_SLOT,
238 LP_PATCH_CHUNK_OFFSETS_SLOT
239 );
240
241 let dtype = array.dtype().clone();
242 let right_bit_width = array.right_bit_width();
243 let ALPRDDataParts {
244 left_parts,
245 right_parts,
246 left_parts_dictionary,
247 left_parts_patches,
248 } = ALPRDArrayOwnedExt::into_data_parts(array);
249 let ptype = dtype.as_ptype();
250
251 let left_parts = left_parts
252 .try_downcast::<Primitive>()
253 .ok()
254 .vortex_expect("ALPRD execute: left_parts is primitive");
255 let right_parts = right_parts
256 .try_downcast::<Primitive>()
257 .ok()
258 .vortex_expect("ALPRD execute: right_parts is primitive");
259
260 let left_parts_dict = left_parts_dictionary;
262 let validity = left_parts.validity_mask()?;
263
264 let decoded_array = if ptype == PType::F32 {
265 PrimitiveArray::new(
266 alp_rd_decode::<f32>(
267 left_parts.into_buffer_mut::<u16>(),
268 &left_parts_dict,
269 right_bit_width,
270 right_parts.into_buffer_mut::<u32>(),
271 left_parts_patches,
272 ctx,
273 )?,
274 Validity::from_mask(validity, dtype.nullability()),
275 )
276 } else {
277 PrimitiveArray::new(
278 alp_rd_decode::<f64>(
279 left_parts.into_buffer_mut::<u16>(),
280 &left_parts_dict,
281 right_bit_width,
282 right_parts.into_buffer_mut::<u64>(),
283 left_parts_patches,
284 ctx,
285 )?,
286 Validity::from_mask(validity, dtype.nullability()),
287 )
288 };
289
290 Ok(ExecutionResult::done(decoded_array.into_array()))
291 }
292
293 fn reduce_parent(
294 array: ArrayView<'_, Self>,
295 parent: &ArrayRef,
296 child_idx: usize,
297 ) -> VortexResult<Option<ArrayRef>> {
298 RULES.evaluate(array, parent, child_idx)
299 }
300
301 fn execute_parent(
302 array: ArrayView<'_, Self>,
303 parent: &ArrayRef,
304 child_idx: usize,
305 ctx: &mut ExecutionCtx,
306 ) -> VortexResult<Option<ArrayRef>> {
307 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
308 }
309}
310
311pub(super) const LEFT_PARTS_SLOT: usize = 0;
313pub(super) const RIGHT_PARTS_SLOT: usize = 1;
315pub(super) const LP_PATCH_INDICES_SLOT: usize = 2;
317pub(super) const LP_PATCH_VALUES_SLOT: usize = 3;
319pub(super) const LP_PATCH_CHUNK_OFFSETS_SLOT: usize = 4;
321pub(super) const NUM_SLOTS: usize = 5;
322pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = [
323 "left_parts",
324 "right_parts",
325 "patch_indices",
326 "patch_values",
327 "patch_chunk_offsets",
328];
329
/// Non-slot state of an ALP-RD array.
///
/// The child arrays (left parts, right parts and patch arrays) live in the array's slots; this
/// struct keeps the dictionary, the right bit width and the patch offsets needed to rebuild
/// [`Patches`] from those slots.
#[derive(Clone, Debug)]
pub struct ALPRDData {
    /// `Patches::offset` of the left-parts patches; `None` when there are no patches.
    patch_offset: Option<usize>,
    /// `Patches::offset_within_chunk` of the left-parts patches, if any.
    patch_offset_within_chunk: Option<usize>,
    /// Dictionary passed to `alp_rd_decode` when decoding the left parts.
    left_parts_dictionary: Buffer<u16>,
    /// Bit width of the right parts, passed to `alp_rd_decode`.
    right_bit_width: u8,
}
337
338impl Display for ALPRDData {
339 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
340 write!(f, "right_bit_width: {}", self.right_bit_width)?;
341 if let Some(offset) = self.patch_offset {
342 write!(f, ", patch_offset: {offset}")?;
343 }
344 Ok(())
345 }
346}
347
/// Owned decomposition of an [`ALPRDArray`]: both children, the optional left-parts patches and
/// the left-parts dictionary.
#[derive(Clone, Debug)]
pub struct ALPRDDataParts {
    /// The `left_parts` child array.
    pub left_parts: ArrayRef,
    /// Patches applied to the left parts, if any.
    pub left_parts_patches: Option<Patches>,
    /// Dictionary used when decoding the left parts.
    pub left_parts_dictionary: Buffer<u16>,
    /// The `right_parts` child array.
    pub right_parts: ArrayRef,
}
355
/// Marker type for the ALP-RD encoding. Its [`VTable`] implementation handles validation,
/// (de)serialization and decoding of [`ALPRDArray`]s.
#[derive(Clone, Debug)]
pub struct ALPRD;
358
impl ALPRD {
    /// Identifier under which this encoding is registered.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.alprd");

    /// Builds an [`ALPRDArray`] from its parts, checking them on the way.
    ///
    /// The logical length is taken from `left_parts`. Patches, if given, are first normalized
    /// by `ALPRDData::canonicalize_patches`: they must be all-valid, and their values are cast to
    /// the non-nullable `left_parts` dtype and canonicalized. The assembled parts are then handed
    /// to `Array::try_from_parts`.
    ///
    /// # Errors
    ///
    /// Returns an error if the patch values contain nulls, if casting or canonicalizing them
    /// fails, or if `Array::try_from_parts` rejects the parts.
    pub fn try_new(
        dtype: DType,
        left_parts: ArrayRef,
        left_parts_dictionary: Buffer<u16>,
        right_parts: ArrayRef,
        right_bit_width: u8,
        left_parts_patches: Option<Patches>,
    ) -> VortexResult<ALPRDArray> {
        let len = left_parts.len();
        let left_parts_patches = ALPRDData::canonicalize_patches(&left_parts, left_parts_patches)?;
        let slots = ALPRDData::make_slots(&left_parts, &right_parts, &left_parts_patches);
        let data = ALPRDData::new(left_parts_dictionary, right_bit_width, left_parts_patches);
        Array::try_from_parts(ArrayParts::new(ALPRD, dtype, len, data).with_slots(slots))
    }

    /// Builds an [`ALPRDArray`] without normalizing patches or checking the parts.
    ///
    /// # Safety
    ///
    /// Unlike [`ALPRD::try_new`], this skips `ALPRDData::canonicalize_patches` and uses
    /// `Array::from_parts_unchecked`. The caller must supply parts that `try_new` would accept
    /// unchanged, i.e. that satisfy the checks in `validate_parts`:
    /// - an f32/f64 `dtype`;
    /// - `right_parts` with the same length and the matching non-nullable u32/u64 dtype;
    /// - unsigned-integer `left_parts` whose nullability matches `dtype`;
    /// - patches, if any, that are all-valid with values already in the non-nullable
    ///   `left_parts` dtype.
    ///
    /// NOTE(review): confirm whether the values must also already be canonical, as
    /// `canonicalize_patches` would have made them.
    pub unsafe fn new_unchecked(
        dtype: DType,
        left_parts: ArrayRef,
        left_parts_dictionary: Buffer<u16>,
        right_parts: ArrayRef,
        right_bit_width: u8,
        left_parts_patches: Option<Patches>,
    ) -> ALPRDArray {
        let len = left_parts.len();
        let slots = ALPRDData::make_slots(&left_parts, &right_parts, &left_parts_patches);
        // SAFETY: forwarded from this function's contract (see `# Safety` above).
        let data = unsafe {
            ALPRDData::new_unchecked(left_parts_dictionary, right_bit_width, left_parts_patches)
        };
        // SAFETY: the caller guarantees the parts already satisfy the checked path's invariants.
        unsafe {
            Array::from_parts_unchecked(ArrayParts::new(ALPRD, dtype, len, data).with_slots(slots))
        }
    }
}
397
398impl ALPRDData {
399 fn canonicalize_patches(
400 left_parts: &ArrayRef,
401 left_parts_patches: Option<Patches>,
402 ) -> VortexResult<Option<Patches>> {
403 left_parts_patches
404 .map(|patches| {
405 if !patches.values().all_valid()? {
406 vortex_bail!("patches must be all valid: {}", patches.values());
407 }
408 let mut patches = patches.cast_values(&left_parts.dtype().as_nonnullable())?;
411 *patches.values_mut() = patches.values().to_canonical()?.into_array();
414 Ok(patches)
415 })
416 .transpose()
417 }
418
419 pub fn new(
421 left_parts_dictionary: Buffer<u16>,
422 right_bit_width: u8,
423 left_parts_patches: Option<Patches>,
424 ) -> Self {
425 let (patch_offset, patch_offset_within_chunk) = match &left_parts_patches {
426 Some(patches) => (Some(patches.offset()), patches.offset_within_chunk()),
427 None => (None, None),
428 };
429
430 Self {
431 patch_offset,
432 patch_offset_within_chunk,
433 left_parts_dictionary,
434 right_bit_width,
435 }
436 }
437
438 pub(crate) unsafe fn new_unchecked(
441 left_parts_dictionary: Buffer<u16>,
442 right_bit_width: u8,
443 left_parts_patches: Option<Patches>,
444 ) -> Self {
445 Self::new(left_parts_dictionary, right_bit_width, left_parts_patches)
446 }
447
448 fn make_slots(
449 left_parts: &ArrayRef,
450 right_parts: &ArrayRef,
451 patches: &Option<Patches>,
452 ) -> Vec<Option<ArrayRef>> {
453 let (pi, pv, pco) = match patches {
454 Some(p) => (
455 Some(p.indices().clone()),
456 Some(p.values().clone()),
457 p.chunk_offsets().clone(),
458 ),
459 None => (None, None, None),
460 };
461 vec![
462 Some(left_parts.clone()),
463 Some(right_parts.clone()),
464 pi,
465 pv,
466 pco,
467 ]
468 }
469
470 pub fn into_parts(self, left_parts: ArrayRef, right_parts: ArrayRef) -> ALPRDDataParts {
472 ALPRDDataParts {
473 left_parts,
474 left_parts_patches: None,
475 left_parts_dictionary: self.left_parts_dictionary,
476 right_parts,
477 }
478 }
479
480 #[inline]
481 pub fn right_bit_width(&self) -> u8 {
482 self.right_bit_width
483 }
484
485 #[inline]
487 pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
488 &self.left_parts_dictionary
489 }
490}
491
492fn left_parts_from_slots(slots: &[Option<ArrayRef>]) -> &ArrayRef {
493 slots[LEFT_PARTS_SLOT]
494 .as_ref()
495 .vortex_expect("ALPRDArray left_parts slot")
496}
497
498fn right_parts_from_slots(slots: &[Option<ArrayRef>]) -> &ArrayRef {
499 slots[RIGHT_PARTS_SLOT]
500 .as_ref()
501 .vortex_expect("ALPRDArray right_parts slot")
502}
503
504fn patches_from_slots(
505 slots: &[Option<ArrayRef>],
506 patch_offset: Option<usize>,
507 patch_offset_within_chunk: Option<usize>,
508 len: usize,
509) -> Option<Patches> {
510 match (&slots[LP_PATCH_INDICES_SLOT], &slots[LP_PATCH_VALUES_SLOT]) {
511 (Some(indices), Some(values)) => {
512 let patch_offset = patch_offset.vortex_expect("ALPRDArray patch slots without offset");
513 Some(unsafe {
514 Patches::new_unchecked(
515 len,
516 patch_offset,
517 indices.clone(),
518 values.clone(),
519 slots[LP_PATCH_CHUNK_OFFSETS_SLOT].clone(),
520 patch_offset_within_chunk,
521 )
522 })
523 }
524 _ => None,
525 }
526}
527
528fn validate_parts(
529 dtype: &DType,
530 len: usize,
531 left_parts: &ArrayRef,
532 right_parts: &ArrayRef,
533 left_parts_patches: Option<&Patches>,
534) -> VortexResult<()> {
535 if !dtype.is_float() {
536 vortex_bail!("ALPRDArray given invalid DType ({dtype})");
537 }
538
539 vortex_ensure!(
540 left_parts.len() == len,
541 "left_parts len {} != outer len {len}",
542 left_parts.len(),
543 );
544 vortex_ensure!(
545 right_parts.len() == len,
546 "right_parts len {} != outer len {len}",
547 right_parts.len(),
548 );
549
550 if !left_parts.dtype().is_unsigned_int() {
551 vortex_bail!("left_parts dtype must be uint");
552 }
553 if dtype.is_nullable() != left_parts.dtype().is_nullable() {
554 vortex_bail!(
555 "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
556 dtype,
557 left_parts.dtype()
558 );
559 }
560
561 let expected_right_parts_dtype = match dtype {
562 DType::Primitive(PType::F32, _) => DType::Primitive(PType::U32, Nullability::NonNullable),
563 DType::Primitive(PType::F64, _) => DType::Primitive(PType::U64, Nullability::NonNullable),
564 _ => vortex_bail!("Expected f32 or f64 dtype, got {:?}", dtype),
565 };
566 vortex_ensure!(
567 right_parts.dtype() == &expected_right_parts_dtype,
568 "right_parts dtype {} does not match expected {}",
569 right_parts.dtype(),
570 expected_right_parts_dtype,
571 );
572
573 if let Some(patches) = left_parts_patches {
574 vortex_ensure!(
575 patches.array_len() == len,
576 "patches array_len {} != outer len {len}",
577 patches.array_len(),
578 );
579 vortex_ensure!(
580 patches.dtype().eq_ignore_nullability(left_parts.dtype()),
581 "patches dtype {} does not match left_parts dtype {}",
582 patches.dtype(),
583 left_parts.dtype(),
584 );
585 vortex_ensure!(
586 patches.values().all_valid()?,
587 "patches must be all valid: {}",
588 patches.values()
589 );
590 }
591
592 Ok(())
593}
594
/// Accessors available on every typed reference to an ALP-RD array.
pub trait ALPRDArrayExt: TypedArrayRef<ALPRD> {
    /// The `left_parts` child, read from `LEFT_PARTS_SLOT`.
    fn left_parts(&self) -> &ArrayRef {
        left_parts_from_slots(self.as_ref().slots())
    }

    /// The `right_parts` child, read from `RIGHT_PARTS_SLOT`.
    fn right_parts(&self) -> &ArrayRef {
        right_parts_from_slots(self.as_ref().slots())
    }

    /// Bit width of the right parts.
    fn right_bit_width(&self) -> u8 {
        // Call the `ALPRDData` method by path: `self.right_bit_width()` would resolve back to
        // this same-named trait method.
        ALPRDData::right_bit_width(self)
    }

    /// Rebuilds the left-parts patches from the patch slots and the offsets kept in the data.
    fn left_parts_patches(&self) -> Option<Patches> {
        patches_from_slots(
            self.as_ref().slots(),
            self.patch_offset,
            self.patch_offset_within_chunk,
            self.as_ref().len(),
        )
    }

    /// Dictionary used when decoding the left parts.
    fn left_parts_dictionary(&self) -> &Buffer<u16> {
        // Same-name trait method: call the `ALPRDData` accessor by path.
        ALPRDData::left_parts_dictionary(self)
    }
}
impl<T: TypedArrayRef<ALPRD>> ALPRDArrayExt for T {}
622
/// Consuming accessors for owned ALP-RD arrays.
pub trait ALPRDArrayOwnedExt {
    /// Splits the array into its children, dictionary and (rebuilt) left-parts patches.
    fn into_data_parts(self) -> ALPRDDataParts;
}
626
627impl ALPRDArrayOwnedExt for Array<ALPRD> {
628 fn into_data_parts(self) -> ALPRDDataParts {
629 let left_parts_patches = self.left_parts_patches();
630 let left_parts = self.left_parts().clone();
631 let right_parts = self.right_parts().clone();
632 let mut parts = ALPRDDataParts {
633 left_parts,
634 left_parts_patches: None,
635 left_parts_dictionary: self.left_parts_dictionary().clone(),
636 right_parts,
637 };
638 parts.left_parts_patches = left_parts_patches;
639 parts
640 }
641}
642
643impl ValidityChild<ALPRD> for ALPRD {
644 fn validity_child(array: ArrayView<'_, ALPRD>) -> ArrayRef {
645 array.left_parts().clone()
646 }
647}
648
#[cfg(test)]
mod test {
    use prost::Message;
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::PType;
    use vortex_array::patches::PatchesMetadata;
    use vortex_array::test_harness::check_metadata;

    use super::ALPRDMetadata;
    use crate::ALPRDFloat;
    use crate::alp_rd;

    /// Positions that are replaced with nulls in the round-trip fixture.
    const NULL_POSITIONS: [usize; 3] = [1, 5, 900];

    /// Encodes a 1024-element float array containing a few nulls with ALP-RD and checks that
    /// decoding reproduces the nullable input exactly.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");

        let reals: Vec<Option<T>> = reals
            .into_iter()
            .enumerate()
            .map(|(idx, value)| (!NULL_POSITIONS.contains(&idx)).then_some(value))
            .collect();
        let input = PrimitiveArray::from_option_iter(reals.iter().cloned());

        let encoder: alp_rd::RDEncoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);
        let rd_array = encoder.encode(&input);
        let decoded = rd_array.as_array().to_primitive();

        assert_arrays_eq!(decoded, PrimitiveArray::from_option_iter(reals));
    }

    /// Guards the serialized metadata layout against the checked-in golden file.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_alprd_metadata() {
        let metadata = ALPRDMetadata {
            right_bit_width: u32::MAX,
            dict_len: 8,
            dict: Vec::new(),
            left_parts_ptype: PType::U64 as i32,
            patches: Some(PatchesMetadata::new(
                usize::MAX,
                usize::MAX,
                PType::U64,
                None,
                None,
                None,
            )),
        };
        check_metadata("alprd.metadata", &metadata.encode_to_vec());
    }
}