1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9
10use itertools::Itertools;
11use prost::Message;
12use vortex_array::Array;
13use vortex_array::ArrayEq;
14use vortex_array::ArrayHash;
15use vortex_array::ArrayId;
16use vortex_array::ArrayParts;
17use vortex_array::ArrayRef;
18use vortex_array::ArrayView;
19use vortex_array::ExecutionCtx;
20use vortex_array::ExecutionResult;
21use vortex_array::IntoArray;
22use vortex_array::LEGACY_SESSION;
23use vortex_array::Precision;
24use vortex_array::TypedArrayRef;
25use vortex_array::VortexSessionExecute;
26use vortex_array::arrays::Primitive;
27use vortex_array::arrays::PrimitiveArray;
28use vortex_array::buffer::BufferHandle;
29use vortex_array::dtype::DType;
30use vortex_array::dtype::Nullability;
31use vortex_array::dtype::PType;
32use vortex_array::patches::Patches;
33use vortex_array::patches::PatchesMetadata;
34use vortex_array::require_child;
35use vortex_array::require_patches;
36use vortex_array::serde::ArrayChildren;
37use vortex_array::validity::Validity;
38use vortex_array::vtable::VTable;
39use vortex_array::vtable::ValidityChild;
40use vortex_array::vtable::ValidityVTableFromChild;
41use vortex_buffer::Buffer;
42use vortex_error::VortexExpect;
43use vortex_error::VortexResult;
44use vortex_error::vortex_bail;
45use vortex_error::vortex_ensure;
46use vortex_error::vortex_err;
47use vortex_error::vortex_panic;
48use vortex_session::VortexSession;
49use vortex_session::registry::CachedId;
50
51use crate::alp_rd::kernel::PARENT_KERNELS;
52use crate::alp_rd::rules::RULES;
53use crate::alp_rd_decode;
54
/// An [`Array`] using the [`ALPRD`] encoding (array id `vortex.alprd`).
pub type ALPRDArray = Array<ALPRD>;
57
/// Serialized (protobuf) metadata for an ALP-RD array.
///
/// The child arrays themselves are serialized separately; this message only carries the
/// scalar state needed to rebuild the array on read.
#[derive(Clone, prost::Message)]
pub struct ALPRDMetadata {
    /// Bit width of the right parts; must fit in `u8` when deserialized.
    #[prost(uint32, tag = "1")]
    right_bit_width: u32,
    /// Number of leading entries of `dict` that form the left-parts dictionary.
    #[prost(uint32, tag = "2")]
    dict_len: u32,
    /// Left-parts dictionary entries, widened from `u16` to `u32` on write.
    #[prost(uint32, repeated, tag = "3")]
    dict: Vec<u32>,
    /// `PType` of the `left_parts` child, stored as a prost enumeration value and read back
    /// through the generated `left_parts_ptype()` getter.
    #[prost(enumeration = "PType", tag = "4")]
    left_parts_ptype: i32,
    /// Metadata for the optional left-parts patches.
    #[prost(message, tag = "5")]
    patches: Option<PatchesMetadata>,
}
71
72impl ArrayHash for ALPRDData {
73 fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
74 self.left_parts_dictionary.array_hash(state, precision);
75 self.right_bit_width.hash(state);
76 self.patch_offset.hash(state);
77 self.patch_offset_within_chunk.hash(state);
78 }
79}
80
81impl ArrayEq for ALPRDData {
82 fn array_eq(&self, other: &Self, precision: Precision) -> bool {
83 self.left_parts_dictionary
84 .array_eq(&other.left_parts_dictionary, precision)
85 && self.right_bit_width == other.right_bit_width
86 && self.patch_offset == other.patch_offset
87 && self.patch_offset_within_chunk == other.patch_offset_within_chunk
88 }
89}
90
91impl VTable for ALPRD {
92 type ArrayData = ALPRDData;
93
94 type OperationsVTable = Self;
95 type ValidityVTable = ValidityVTableFromChild;
96
97 fn id(&self) -> ArrayId {
98 static ID: CachedId = CachedId::new("vortex.alprd");
99 *ID
100 }
101
102 fn validate(
103 &self,
104 data: &ALPRDData,
105 dtype: &DType,
106 len: usize,
107 slots: &[Option<ArrayRef>],
108 ) -> VortexResult<()> {
109 validate_parts(
110 dtype,
111 len,
112 left_parts_from_slots(slots),
113 right_parts_from_slots(slots),
114 patches_from_slots(
115 slots,
116 data.patch_offset,
117 data.patch_offset_within_chunk,
118 len,
119 )
120 .as_ref(),
121 )
122 }
123
124 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
125 0
126 }
127
128 fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
129 vortex_panic!("ALPRDArray buffer index {idx} out of bounds")
130 }
131
132 fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
133 None
134 }
135
136 fn serialize(
137 array: ArrayView<'_, Self>,
138 _session: &VortexSession,
139 ) -> VortexResult<Option<Vec<u8>>> {
140 let dict = array
141 .left_parts_dictionary()
142 .iter()
143 .map(|&i| i as u32)
144 .collect::<Vec<_>>();
145
146 Ok(Some(
147 ALPRDMetadata {
148 right_bit_width: array.right_bit_width() as u32,
149 dict_len: array.left_parts_dictionary().len() as u32,
150 dict,
151 left_parts_ptype: array.left_parts().dtype().as_ptype() as i32,
152 patches: array
153 .left_parts_patches()
154 .map(|p| p.to_metadata(array.len(), p.dtype()))
155 .transpose()?,
156 }
157 .encode_to_vec(),
158 ))
159 }
160
161 fn deserialize(
162 &self,
163 dtype: &DType,
164 len: usize,
165 metadata: &[u8],
166 _buffers: &[BufferHandle],
167 children: &dyn ArrayChildren,
168 _session: &VortexSession,
169 ) -> VortexResult<ArrayParts<Self>> {
170 let metadata = ALPRDMetadata::decode(metadata)?;
171 if children.len() < 2 {
172 vortex_bail!(
173 "Expected at least 2 children for ALPRD encoding, found {}",
174 children.len()
175 );
176 }
177
178 let left_parts_dtype = DType::Primitive(metadata.left_parts_ptype(), dtype.nullability());
179 let left_parts = children.get(0, &left_parts_dtype, len)?;
180 let left_parts_dictionary: Buffer<u16> = metadata.dict.as_slice()
181 [0..metadata.dict_len as usize]
182 .iter()
183 .map(|&i| {
184 u16::try_from(i)
185 .map_err(|_| vortex_err!("left_parts_dictionary code {i} does not fit in u16"))
186 })
187 .try_collect()?;
188
189 let right_parts_dtype = match &dtype {
190 DType::Primitive(PType::F32, _) => {
191 DType::Primitive(PType::U32, Nullability::NonNullable)
192 }
193 DType::Primitive(PType::F64, _) => {
194 DType::Primitive(PType::U64, Nullability::NonNullable)
195 }
196 _ => vortex_bail!("Expected f32 or f64 dtype, got {:?}", dtype),
197 };
198 let right_parts = children.get(1, &right_parts_dtype, len)?;
199
200 let left_parts_patches = metadata
201 .patches
202 .map(|p| {
203 let indices = children.get(2, &p.indices_dtype()?, p.len()?)?;
204 let values = children.get(3, &left_parts_dtype.as_nonnullable(), p.len()?)?;
205
206 Patches::new(
207 len,
208 p.offset()?,
209 indices,
210 values,
211 None,
213 )
214 })
215 .transpose()?;
216 let left_parts_patches = ALPRDData::canonicalize_patches(&left_parts, left_parts_patches)?;
217 let slots = ALPRDData::make_slots(&left_parts, &right_parts, &left_parts_patches);
218 let data = ALPRDData::new(
219 left_parts_dictionary,
220 u8::try_from(metadata.right_bit_width).map_err(|_| {
221 vortex_err!(
222 "right_bit_width {} out of u8 range",
223 metadata.right_bit_width
224 )
225 })?,
226 left_parts_patches,
227 );
228 Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
229 }
230
231 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
232 SLOT_NAMES[idx].to_string()
233 }
234
235 fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
236 let array = require_child!(array, array.left_parts(), 0 => Primitive);
237 let array = require_child!(array, array.right_parts(), 1 => Primitive);
238 require_patches!(
239 array,
240 LP_PATCH_INDICES_SLOT,
241 LP_PATCH_VALUES_SLOT,
242 LP_PATCH_CHUNK_OFFSETS_SLOT
243 );
244
245 let dtype = array.dtype().clone();
246 let right_bit_width = array.right_bit_width();
247 let ALPRDDataParts {
248 left_parts,
249 right_parts,
250 left_parts_dictionary,
251 left_parts_patches,
252 } = ALPRDArrayOwnedExt::into_data_parts(array);
253 let ptype = dtype.as_ptype();
254
255 let left_parts = left_parts
256 .try_downcast::<Primitive>()
257 .ok()
258 .vortex_expect("ALPRD execute: left_parts is primitive");
259 let right_parts = right_parts
260 .try_downcast::<Primitive>()
261 .ok()
262 .vortex_expect("ALPRD execute: right_parts is primitive");
263
264 let left_parts_dict = left_parts_dictionary;
266 let validity = left_parts
267 .as_ref()
268 .validity()?
269 .to_mask(left_parts.as_ref().len(), ctx)?;
270
271 let decoded_array = if ptype == PType::F32 {
272 PrimitiveArray::new(
273 alp_rd_decode::<f32>(
274 left_parts.into_buffer_mut::<u16>(),
275 &left_parts_dict,
276 right_bit_width,
277 right_parts.into_buffer_mut::<u32>(),
278 left_parts_patches,
279 ctx,
280 )?,
281 Validity::from_mask(validity, dtype.nullability()),
282 )
283 } else {
284 PrimitiveArray::new(
285 alp_rd_decode::<f64>(
286 left_parts.into_buffer_mut::<u16>(),
287 &left_parts_dict,
288 right_bit_width,
289 right_parts.into_buffer_mut::<u64>(),
290 left_parts_patches,
291 ctx,
292 )?,
293 Validity::from_mask(validity, dtype.nullability()),
294 )
295 };
296
297 Ok(ExecutionResult::done(decoded_array.into_array()))
298 }
299
300 fn reduce_parent(
301 array: ArrayView<'_, Self>,
302 parent: &ArrayRef,
303 child_idx: usize,
304 ) -> VortexResult<Option<ArrayRef>> {
305 RULES.evaluate(array, parent, child_idx)
306 }
307
308 fn execute_parent(
309 array: ArrayView<'_, Self>,
310 parent: &ArrayRef,
311 child_idx: usize,
312 ctx: &mut ExecutionCtx,
313 ) -> VortexResult<Option<ArrayRef>> {
314 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
315 }
316}
317
/// Slot index of the `left_parts` child.
pub(super) const LEFT_PARTS_SLOT: usize = 0;
/// Slot index of the `right_parts` child.
pub(super) const RIGHT_PARTS_SLOT: usize = 1;
/// Slot index of the left-parts patch indices (empty when there are no patches).
pub(super) const LP_PATCH_INDICES_SLOT: usize = 2;
/// Slot index of the left-parts patch values (empty when there are no patches).
pub(super) const LP_PATCH_VALUES_SLOT: usize = 3;
/// Slot index of the optional left-parts patch chunk offsets.
pub(super) const LP_PATCH_CHUNK_OFFSETS_SLOT: usize = 4;
/// Total number of slots an ALP-RD array has.
pub(super) const NUM_SLOTS: usize = 5;
/// Human-readable slot names, indexed by the slot constants above.
pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = [
    "left_parts",
    "right_parts",
    "patch_indices",
    "patch_values",
    "patch_chunk_offsets",
];
336
/// Per-array ALP-RD state that is not held in child slots.
///
/// The child arrays (left/right parts and the patch arrays) live in the array's slots. This
/// struct keeps the dictionary, the right bit width, and the scalar patch offsets needed to
/// reassemble the left-parts `Patches` from those slots.
#[derive(Clone, Debug)]
pub struct ALPRDData {
    /// `Patches::offset()` of the left-parts patches, or `None` when there are no patches.
    patch_offset: Option<usize>,
    /// `Patches::offset_within_chunk()` of the left-parts patches; `None` when there are no
    /// patches or the patches report none.
    patch_offset_within_chunk: Option<usize>,
    /// Dictionary for the left parts (serialized widened to `u32`).
    left_parts_dictionary: Buffer<u16>,
    /// Bit width of the right parts.
    right_bit_width: u8,
}
344
345impl Display for ALPRDData {
346 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
347 write!(f, "right_bit_width: {}", self.right_bit_width)?;
348 if let Some(offset) = self.patch_offset {
349 write!(f, ", patch_offset: {offset}")?;
350 }
351 Ok(())
352 }
353}
354
/// The decomposed components of an ALP-RD array.
#[derive(Clone, Debug)]
pub struct ALPRDDataParts {
    /// The `left_parts` child array.
    pub left_parts: ArrayRef,
    /// Optional patches applied to `left_parts`.
    pub left_parts_patches: Option<Patches>,
    /// Dictionary for the left parts.
    pub left_parts_dictionary: Buffer<u16>,
    /// The `right_parts` child array (`u32` for f32 arrays, `u64` for f64 arrays).
    pub right_parts: ArrayRef,
}
362
/// The ALP-RD encoding marker type; its array id is `vortex.alprd`.
#[derive(Clone, Debug)]
pub struct ALPRD;
365
366impl ALPRD {
367 pub fn try_new(
368 dtype: DType,
369 left_parts: ArrayRef,
370 left_parts_dictionary: Buffer<u16>,
371 right_parts: ArrayRef,
372 right_bit_width: u8,
373 left_parts_patches: Option<Patches>,
374 ) -> VortexResult<ALPRDArray> {
375 let len = left_parts.len();
376 let left_parts_patches = ALPRDData::canonicalize_patches(&left_parts, left_parts_patches)?;
377 let slots = ALPRDData::make_slots(&left_parts, &right_parts, &left_parts_patches);
378 let data = ALPRDData::new(left_parts_dictionary, right_bit_width, left_parts_patches);
379 Array::try_from_parts(ArrayParts::new(ALPRD, dtype, len, data).with_slots(slots))
380 }
381
382 pub unsafe fn new_unchecked(
385 dtype: DType,
386 left_parts: ArrayRef,
387 left_parts_dictionary: Buffer<u16>,
388 right_parts: ArrayRef,
389 right_bit_width: u8,
390 left_parts_patches: Option<Patches>,
391 ) -> ALPRDArray {
392 let len = left_parts.len();
393 let slots = ALPRDData::make_slots(&left_parts, &right_parts, &left_parts_patches);
394 let data = unsafe {
395 ALPRDData::new_unchecked(left_parts_dictionary, right_bit_width, left_parts_patches)
396 };
397 unsafe {
398 Array::from_parts_unchecked(ArrayParts::new(ALPRD, dtype, len, data).with_slots(slots))
399 }
400 }
401}
402
403impl ALPRDData {
404 fn canonicalize_patches(
405 left_parts: &ArrayRef,
406 left_parts_patches: Option<Patches>,
407 ) -> VortexResult<Option<Patches>> {
408 left_parts_patches
409 .map(|patches| {
410 if !patches
411 .values()
412 .all_valid(&mut LEGACY_SESSION.create_execution_ctx())?
413 {
414 vortex_bail!("patches must be all valid: {}", patches.values());
415 }
416 let mut patches = patches.cast_values(&left_parts.dtype().as_nonnullable())?;
419 *patches.values_mut() = patches.values().to_canonical()?.into_array();
422 Ok(patches)
423 })
424 .transpose()
425 }
426
427 pub fn new(
429 left_parts_dictionary: Buffer<u16>,
430 right_bit_width: u8,
431 left_parts_patches: Option<Patches>,
432 ) -> Self {
433 let (patch_offset, patch_offset_within_chunk) = match &left_parts_patches {
434 Some(patches) => (Some(patches.offset()), patches.offset_within_chunk()),
435 None => (None, None),
436 };
437
438 Self {
439 patch_offset,
440 patch_offset_within_chunk,
441 left_parts_dictionary,
442 right_bit_width,
443 }
444 }
445
446 pub(crate) unsafe fn new_unchecked(
449 left_parts_dictionary: Buffer<u16>,
450 right_bit_width: u8,
451 left_parts_patches: Option<Patches>,
452 ) -> Self {
453 Self::new(left_parts_dictionary, right_bit_width, left_parts_patches)
454 }
455
456 fn make_slots(
457 left_parts: &ArrayRef,
458 right_parts: &ArrayRef,
459 patches: &Option<Patches>,
460 ) -> Vec<Option<ArrayRef>> {
461 let (pi, pv, pco) = match patches {
462 Some(p) => (
463 Some(p.indices().clone()),
464 Some(p.values().clone()),
465 p.chunk_offsets().clone(),
466 ),
467 None => (None, None, None),
468 };
469 vec![
470 Some(left_parts.clone()),
471 Some(right_parts.clone()),
472 pi,
473 pv,
474 pco,
475 ]
476 }
477
478 pub fn into_parts(self, left_parts: ArrayRef, right_parts: ArrayRef) -> ALPRDDataParts {
480 ALPRDDataParts {
481 left_parts,
482 left_parts_patches: None,
483 left_parts_dictionary: self.left_parts_dictionary,
484 right_parts,
485 }
486 }
487
488 #[inline]
489 pub fn right_bit_width(&self) -> u8 {
490 self.right_bit_width
491 }
492
493 #[inline]
495 pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
496 &self.left_parts_dictionary
497 }
498}
499
500fn left_parts_from_slots(slots: &[Option<ArrayRef>]) -> &ArrayRef {
501 slots[LEFT_PARTS_SLOT]
502 .as_ref()
503 .vortex_expect("ALPRDArray left_parts slot")
504}
505
506fn right_parts_from_slots(slots: &[Option<ArrayRef>]) -> &ArrayRef {
507 slots[RIGHT_PARTS_SLOT]
508 .as_ref()
509 .vortex_expect("ALPRDArray right_parts slot")
510}
511
512fn patches_from_slots(
513 slots: &[Option<ArrayRef>],
514 patch_offset: Option<usize>,
515 patch_offset_within_chunk: Option<usize>,
516 len: usize,
517) -> Option<Patches> {
518 match (&slots[LP_PATCH_INDICES_SLOT], &slots[LP_PATCH_VALUES_SLOT]) {
519 (Some(indices), Some(values)) => {
520 let patch_offset = patch_offset.vortex_expect("ALPRDArray patch slots without offset");
521 Some(unsafe {
522 Patches::new_unchecked(
523 len,
524 patch_offset,
525 indices.clone(),
526 values.clone(),
527 slots[LP_PATCH_CHUNK_OFFSETS_SLOT].clone(),
528 patch_offset_within_chunk,
529 )
530 })
531 }
532 _ => None,
533 }
534}
535
536fn validate_parts(
537 dtype: &DType,
538 len: usize,
539 left_parts: &ArrayRef,
540 right_parts: &ArrayRef,
541 left_parts_patches: Option<&Patches>,
542) -> VortexResult<()> {
543 if !dtype.is_float() {
544 vortex_bail!("ALPRDArray given invalid DType ({dtype})");
545 }
546
547 vortex_ensure!(
548 left_parts.len() == len,
549 "left_parts len {} != outer len {len}",
550 left_parts.len(),
551 );
552 vortex_ensure!(
553 right_parts.len() == len,
554 "right_parts len {} != outer len {len}",
555 right_parts.len(),
556 );
557
558 if !left_parts.dtype().is_unsigned_int() {
559 vortex_bail!("left_parts dtype must be uint");
560 }
561 if dtype.is_nullable() != left_parts.dtype().is_nullable() {
562 vortex_bail!(
563 "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
564 dtype,
565 left_parts.dtype()
566 );
567 }
568
569 let expected_right_parts_dtype = match dtype {
570 DType::Primitive(PType::F32, _) => DType::Primitive(PType::U32, Nullability::NonNullable),
571 DType::Primitive(PType::F64, _) => DType::Primitive(PType::U64, Nullability::NonNullable),
572 _ => vortex_bail!("Expected f32 or f64 dtype, got {:?}", dtype),
573 };
574 vortex_ensure!(
575 right_parts.dtype() == &expected_right_parts_dtype,
576 "right_parts dtype {} does not match expected {}",
577 right_parts.dtype(),
578 expected_right_parts_dtype,
579 );
580
581 if let Some(patches) = left_parts_patches {
582 vortex_ensure!(
583 patches.array_len() == len,
584 "patches array_len {} != outer len {len}",
585 patches.array_len(),
586 );
587 vortex_ensure!(
588 patches.dtype().eq_ignore_nullability(left_parts.dtype()),
589 "patches dtype {} does not match left_parts dtype {}",
590 patches.dtype(),
591 left_parts.dtype(),
592 );
593 vortex_ensure!(
594 patches
595 .values()
596 .all_valid(&mut LEGACY_SESSION.create_execution_ctx())?,
597 "patches must be all valid: {}",
598 patches.values()
599 );
600 }
601
602 Ok(())
603}
604
/// Typed accessors for the children and state of an [`ALPRD`] array.
///
/// Implemented for every `TypedArrayRef<ALPRD>` by the blanket impl that follows the trait.
pub trait ALPRDArrayExt: TypedArrayRef<ALPRD> {
    /// The `left_parts` child.
    ///
    /// # Panics
    /// Panics if the left-parts slot is empty.
    fn left_parts(&self) -> &ArrayRef {
        left_parts_from_slots(self.as_ref().slots())
    }

    /// The `right_parts` child.
    ///
    /// # Panics
    /// Panics if the right-parts slot is empty.
    fn right_parts(&self) -> &ArrayRef {
        right_parts_from_slots(self.as_ref().slots())
    }

    /// Bit width of the right parts.
    fn right_bit_width(&self) -> u8 {
        ALPRDData::right_bit_width(self)
    }

    /// Reassembles the left-parts patches from the patch slots and the offsets held in the
    /// array data; `None` unless both the indices and values slots are populated.
    fn left_parts_patches(&self) -> Option<Patches> {
        patches_from_slots(
            self.as_ref().slots(),
            self.patch_offset,
            self.patch_offset_within_chunk,
            self.as_ref().len(),
        )
    }

    /// Dictionary for the left parts.
    fn left_parts_dictionary(&self) -> &Buffer<u16> {
        ALPRDData::left_parts_dictionary(self)
    }
}
impl<T: TypedArrayRef<ALPRD>> ALPRDArrayExt for T {}
632
/// Consuming decomposition of an owned ALP-RD array.
pub trait ALPRDArrayOwnedExt {
    /// Splits the array into its children, dictionary, and reassembled left-parts patches.
    fn into_data_parts(self) -> ALPRDDataParts;
}
636
637impl ALPRDArrayOwnedExt for Array<ALPRD> {
638 fn into_data_parts(self) -> ALPRDDataParts {
639 let left_parts_patches = self.left_parts_patches();
640 let left_parts = self.left_parts().clone();
641 let right_parts = self.right_parts().clone();
642 let mut parts = ALPRDDataParts {
643 left_parts,
644 left_parts_patches: None,
645 left_parts_dictionary: self.left_parts_dictionary().clone(),
646 right_parts,
647 };
648 parts.left_parts_patches = left_parts_patches;
649 parts
650 }
651}
652
653impl ValidityChild<ALPRD> for ALPRD {
654 fn validity_child(array: ArrayView<'_, ALPRD>) -> ArrayRef {
655 array.left_parts().clone()
656 }
657}
658
// Unit tests for the ALP-RD array: a nullable encode/decode round-trip and a serialized
// metadata fixture check.
#[cfg(test)]
mod test {
    use prost::Message;
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::PType;
    use vortex_array::patches::PatchesMetadata;
    use vortex_array::test_harness::check_metadata;

    use super::ALPRDMetadata;
    use crate::ALPRDFloat;
    use crate::alp_rd;

    /// Encodes a 1024-element float array containing a few nulls, using an `RDEncoder` built
    /// from `seed.powi(-2)`. It then decodes back to primitive form and checks that values and
    /// null positions round-trip.
    ///
    /// NOTE(review): the name mentions patches, but this test does not assert that the fixture
    /// actually produces left-parts patches — confirm.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");
        let mut reals: Vec<Option<T>> = reals.into_iter().map(Some).collect();
        // Introduce nulls at a few fixed positions.
        reals[1] = None;
        reals[5] = None;
        reals[900] = None;

        let real_array = PrimitiveArray::from_option_iter(reals.iter().cloned());

        let encoder: alp_rd::RDEncoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);

        let rd_array = encoder.encode(real_array.as_view());

        let decoded = rd_array.as_array().to_primitive();

        assert_arrays_eq!(decoded, PrimitiveArray::from_option_iter(reals));
    }

    // Serializes an `ALPRDMetadata` with extreme field values and hands the bytes to
    // `check_metadata` under the name "alprd.metadata". This is presumably compared against a
    // stored fixture, so these field values should not be changed casually.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_alprd_metadata() {
        check_metadata(
            "alprd.metadata",
            &ALPRDMetadata {
                right_bit_width: u32::MAX,
                patches: Some(PatchesMetadata::new(
                    usize::MAX,
                    usize::MAX,
                    PType::U64,
                    None,
                    None,
                    None,
                )),
                dict: Vec::new(),
                left_parts_ptype: PType::U64 as i32,
                dict_len: 8,
            }
            .encode_to_vec(),
        );
    }
}