1use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use itertools::Itertools;
9use vortex_array::ArrayEq;
10use vortex_array::ArrayHash;
11use vortex_array::ArrayRef;
12use vortex_array::DeserializeMetadata;
13use vortex_array::DynArray;
14use vortex_array::ExecutionCtx;
15use vortex_array::ExecutionResult;
16use vortex_array::IntoArray;
17use vortex_array::Precision;
18use vortex_array::ProstMetadata;
19use vortex_array::SerializeMetadata;
20use vortex_array::arrays::Primitive;
21use vortex_array::arrays::PrimitiveArray;
22use vortex_array::buffer::BufferHandle;
23use vortex_array::dtype::DType;
24use vortex_array::dtype::Nullability;
25use vortex_array::dtype::PType;
26use vortex_array::patches::Patches;
27use vortex_array::patches::PatchesMetadata;
28use vortex_array::require_child;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::stats::ArrayStats;
31use vortex_array::stats::StatsSetRef;
32use vortex_array::validity::Validity;
33use vortex_array::vtable;
34use vortex_array::vtable::ArrayId;
35use vortex_array::vtable::VTable;
36use vortex_array::vtable::ValidityChild;
37use vortex_array::vtable::ValidityVTableFromChild;
38use vortex_array::vtable::patches_child;
39use vortex_array::vtable::patches_child_name;
40use vortex_array::vtable::patches_nchildren;
41use vortex_buffer::Buffer;
42use vortex_error::VortexExpect;
43use vortex_error::VortexResult;
44use vortex_error::vortex_bail;
45use vortex_error::vortex_ensure;
46use vortex_error::vortex_err;
47use vortex_error::vortex_panic;
48use vortex_session::VortexSession;
49
50use crate::alp_rd::kernel::PARENT_KERNELS;
51use crate::alp_rd::rules::RULES;
52use crate::alp_rd_decode;
53
54vtable!(ALPRD);
55
56#[derive(Clone, prost::Message)]
57pub struct ALPRDMetadata {
58 #[prost(uint32, tag = "1")]
59 right_bit_width: u32,
60 #[prost(uint32, tag = "2")]
61 dict_len: u32,
62 #[prost(uint32, repeated, tag = "3")]
63 dict: Vec<u32>,
64 #[prost(enumeration = "PType", tag = "4")]
65 left_parts_ptype: i32,
66 #[prost(message, tag = "5")]
67 patches: Option<PatchesMetadata>,
68}
69
70impl VTable for ALPRD {
71 type Array = ALPRDArray;
72
73 type Metadata = ProstMetadata<ALPRDMetadata>;
74 type OperationsVTable = Self;
75 type ValidityVTable = ValidityVTableFromChild;
76
77 fn vtable(_array: &Self::Array) -> &Self {
78 &ALPRD
79 }
80
81 fn id(&self) -> ArrayId {
82 Self::ID
83 }
84
85 fn len(array: &ALPRDArray) -> usize {
86 array.left_parts.len()
87 }
88
89 fn dtype(array: &ALPRDArray) -> &DType {
90 &array.dtype
91 }
92
93 fn stats(array: &ALPRDArray) -> StatsSetRef<'_> {
94 array.stats_set.to_ref(array.as_ref())
95 }
96
97 fn array_hash<H: std::hash::Hasher>(array: &ALPRDArray, state: &mut H, precision: Precision) {
98 array.dtype.hash(state);
99 array.left_parts.array_hash(state, precision);
100 array.left_parts_dictionary.array_hash(state, precision);
101 array.right_parts.array_hash(state, precision);
102 array.right_bit_width.hash(state);
103 array.left_parts_patches.array_hash(state, precision);
104 }
105
106 fn array_eq(array: &ALPRDArray, other: &ALPRDArray, precision: Precision) -> bool {
107 array.dtype == other.dtype
108 && array.left_parts.array_eq(&other.left_parts, precision)
109 && array
110 .left_parts_dictionary
111 .array_eq(&other.left_parts_dictionary, precision)
112 && array.right_parts.array_eq(&other.right_parts, precision)
113 && array.right_bit_width == other.right_bit_width
114 && array
115 .left_parts_patches
116 .array_eq(&other.left_parts_patches, precision)
117 }
118
119 fn nbuffers(_array: &ALPRDArray) -> usize {
120 0
121 }
122
123 fn buffer(_array: &ALPRDArray, idx: usize) -> BufferHandle {
124 vortex_panic!("ALPRDArray buffer index {idx} out of bounds")
125 }
126
127 fn buffer_name(_array: &ALPRDArray, _idx: usize) -> Option<String> {
128 None
129 }
130
131 fn nchildren(array: &ALPRDArray) -> usize {
132 2 + array.left_parts_patches().map_or(0, patches_nchildren)
133 }
134
135 fn child(array: &ALPRDArray, idx: usize) -> ArrayRef {
136 match idx {
137 0 => array.left_parts().clone(),
138 1 => array.right_parts().clone(),
139 _ => {
140 let patches = array
141 .left_parts_patches()
142 .unwrap_or_else(|| vortex_panic!("ALPRDArray child index {idx} out of bounds"));
143 patches_child(patches, idx - 2)
144 }
145 }
146 }
147
148 fn child_name(array: &ALPRDArray, idx: usize) -> String {
149 match idx {
150 0 => "left_parts".to_string(),
151 1 => "right_parts".to_string(),
152 _ => {
153 if array.left_parts_patches().is_none() {
154 vortex_panic!("ALPRDArray child_name index {idx} out of bounds");
155 }
156 patches_child_name(idx - 2).to_string()
157 }
158 }
159 }
160
161 fn metadata(array: &ALPRDArray) -> VortexResult<Self::Metadata> {
162 let dict = array
163 .left_parts_dictionary()
164 .iter()
165 .map(|&i| i as u32)
166 .collect::<Vec<_>>();
167
168 Ok(ProstMetadata(ALPRDMetadata {
169 right_bit_width: array.right_bit_width() as u32,
170 dict_len: array.left_parts_dictionary().len() as u32,
171 dict,
172 left_parts_ptype: array.left_parts.dtype().as_ptype() as i32,
173 patches: array
174 .left_parts_patches()
175 .map(|p| p.to_metadata(array.len(), array.left_parts().dtype()))
176 .transpose()?,
177 }))
178 }
179
180 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
181 Ok(Some(metadata.serialize()))
182 }
183
184 fn deserialize(
185 bytes: &[u8],
186 _dtype: &DType,
187 _len: usize,
188 _buffers: &[BufferHandle],
189 _session: &VortexSession,
190 ) -> VortexResult<Self::Metadata> {
191 Ok(ProstMetadata(
192 <ProstMetadata<ALPRDMetadata> as DeserializeMetadata>::deserialize(bytes)?,
193 ))
194 }
195
196 fn build(
197 dtype: &DType,
198 len: usize,
199 metadata: &Self::Metadata,
200 _buffers: &[BufferHandle],
201 children: &dyn ArrayChildren,
202 ) -> VortexResult<ALPRDArray> {
203 if children.len() < 2 {
204 vortex_bail!(
205 "Expected at least 2 children for ALPRD encoding, found {}",
206 children.len()
207 );
208 }
209
210 let left_parts_dtype = DType::Primitive(metadata.0.left_parts_ptype(), dtype.nullability());
211 let left_parts = children.get(0, &left_parts_dtype, len)?;
212 let left_parts_dictionary: Buffer<u16> = metadata.0.dict.as_slice()
213 [0..metadata.0.dict_len as usize]
214 .iter()
215 .map(|&i| {
216 u16::try_from(i)
217 .map_err(|_| vortex_err!("left_parts_dictionary code {i} does not fit in u16"))
218 })
219 .try_collect()?;
220
221 let right_parts_dtype = match &dtype {
222 DType::Primitive(PType::F32, _) => {
223 DType::Primitive(PType::U32, Nullability::NonNullable)
224 }
225 DType::Primitive(PType::F64, _) => {
226 DType::Primitive(PType::U64, Nullability::NonNullable)
227 }
228 _ => vortex_bail!("Expected f32 or f64 dtype, got {:?}", dtype),
229 };
230 let right_parts = children.get(1, &right_parts_dtype, len)?;
231
232 let left_parts_patches = metadata
233 .0
234 .patches
235 .map(|p| {
236 let indices = children.get(2, &p.indices_dtype()?, p.len()?)?;
237 let values = children.get(3, &left_parts_dtype, p.len()?)?;
238
239 Patches::new(
240 len,
241 p.offset()?,
242 indices,
243 values,
244 None,
246 )
247 })
248 .transpose()?;
249
250 ALPRDArray::try_new(
251 dtype.clone(),
252 left_parts,
253 left_parts_dictionary,
254 right_parts,
255 u8::try_from(metadata.0.right_bit_width).map_err(|_| {
256 vortex_err!(
257 "right_bit_width {} out of u8 range",
258 metadata.0.right_bit_width
259 )
260 })?,
261 left_parts_patches,
262 )
263 }
264
265 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
266 let patches_info = array
268 .left_parts_patches
269 .as_ref()
270 .map(|p| (p.array_len(), p.offset()));
271
272 let expected_children = if patches_info.is_some() { 4 } else { 2 };
273
274 vortex_ensure!(
275 children.len() == expected_children,
276 "ALPRDArray expects {} children, got {}",
277 expected_children,
278 children.len()
279 );
280
281 let mut children_iter = children.into_iter();
282 array.left_parts = children_iter
283 .next()
284 .ok_or_else(|| vortex_err!("Expected left_parts child"))?;
285 array.right_parts = children_iter
286 .next()
287 .ok_or_else(|| vortex_err!("Expected right_parts child"))?;
288
289 if let Some((array_len, offset)) = patches_info {
290 let indices = children_iter
291 .next()
292 .ok_or_else(|| vortex_err!("Expected patch indices child"))?;
293 let values = children_iter
294 .next()
295 .ok_or_else(|| vortex_err!("Expected patch values child"))?;
296
297 array.left_parts_patches = Some(Patches::new(
298 array_len, offset, indices, values,
299 None, )?);
301 }
302
303 Ok(())
304 }
305
306 fn execute(array: Arc<Self::Array>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
307 let array = require_child!(Self, array, array.left_parts(), 0 => Primitive);
308 let array = require_child!(Self, array, array.right_parts(), 1 => Primitive);
309
310 let right_bit_width = array.right_bit_width();
311 let ALPRDArrayParts {
312 left_parts,
313 right_parts,
314 left_parts_dictionary,
315 left_parts_patches,
316 dtype,
317 ..
318 } = Arc::unwrap_or_clone(array).into_parts();
319 let ptype = dtype.as_ptype();
320
321 let left_parts = left_parts
322 .try_into::<Primitive>()
323 .ok()
324 .vortex_expect("ALPRD execute: left_parts is primitive");
325 let right_parts = right_parts
326 .try_into::<Primitive>()
327 .ok()
328 .vortex_expect("ALPRD execute: right_parts is primitive");
329
330 let left_parts_dict = left_parts_dictionary;
332 let validity = left_parts.validity_mask()?;
333
334 let decoded_array = if ptype == PType::F32 {
335 PrimitiveArray::new(
337 alp_rd_decode::<f32>(
338 left_parts.into_buffer::<u16>(),
339 &left_parts_dict,
340 right_bit_width,
341 right_parts.into_buffer::<u32>(),
342 left_parts_patches,
343 ctx,
344 )?,
345 Validity::from_mask(validity, dtype.nullability()),
346 )
347 } else {
348 PrimitiveArray::new(
349 alp_rd_decode::<f64>(
350 left_parts.into_buffer::<u16>(),
351 &left_parts_dict,
352 right_bit_width,
353 right_parts.into_buffer::<u64>(),
354 left_parts_patches,
355 ctx,
356 )?,
357 Validity::from_mask(validity, dtype.nullability()),
358 )
359 };
360
361 Ok(ExecutionResult::done(decoded_array.into_array()))
362 }
363
364 fn reduce_parent(
365 array: &Self::Array,
366 parent: &ArrayRef,
367 child_idx: usize,
368 ) -> VortexResult<Option<ArrayRef>> {
369 RULES.evaluate(array, parent, child_idx)
370 }
371
372 fn execute_parent(
373 array: &Self::Array,
374 parent: &ArrayRef,
375 child_idx: usize,
376 ctx: &mut ExecutionCtx,
377 ) -> VortexResult<Option<ArrayRef>> {
378 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
379 }
380}
381
382#[derive(Clone, Debug)]
383pub struct ALPRDArray {
384 dtype: DType,
385 left_parts: ArrayRef,
386 left_parts_patches: Option<Patches>,
387 left_parts_dictionary: Buffer<u16>,
388 right_parts: ArrayRef,
389 right_bit_width: u8,
390 stats_set: ArrayStats,
391}
392
393#[derive(Clone, Debug)]
394pub struct ALPRDArrayParts {
395 pub dtype: DType,
396 pub left_parts: ArrayRef,
397 pub left_parts_patches: Option<Patches>,
398 pub left_parts_dictionary: Buffer<u16>,
399 pub right_parts: ArrayRef,
400}
401
/// Zero-sized marker type that carries the `VTable` implementation for the
/// ALP-RD encoding.
#[derive(Debug, Clone)]
pub struct ALPRD;
404
405impl ALPRD {
406 pub const ID: ArrayId = ArrayId::new_ref("vortex.alprd");
407}
408
409impl ALPRDArray {
410 pub fn try_new(
412 dtype: DType,
413 left_parts: ArrayRef,
414 left_parts_dictionary: Buffer<u16>,
415 right_parts: ArrayRef,
416 right_bit_width: u8,
417 left_parts_patches: Option<Patches>,
418 ) -> VortexResult<Self> {
419 if !dtype.is_float() {
420 vortex_bail!("ALPRDArray given invalid DType ({dtype})");
421 }
422
423 let len = left_parts.len();
424 if right_parts.len() != len {
425 vortex_bail!(
426 "left_parts (len {}) and right_parts (len {}) must be of same length",
427 len,
428 right_parts.len()
429 );
430 }
431
432 if !left_parts.dtype().is_unsigned_int() {
433 vortex_bail!("left_parts dtype must be uint");
434 }
435 if dtype.is_nullable() != left_parts.dtype().is_nullable() {
437 vortex_bail!(
438 "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
439 dtype,
440 left_parts.dtype()
441 );
442 }
443
444 if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
446 vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
447 }
448
449 let left_parts_patches = left_parts_patches
450 .map(|patches| {
451 if !patches.values().all_valid()? {
452 vortex_bail!("patches must be all valid: {}", patches.values());
453 }
454 let mut patches = patches.cast_values(left_parts.dtype())?;
457 *patches.values_mut() = patches.values().to_canonical()?.into_array();
460 Ok(patches)
461 })
462 .transpose()?;
463
464 Ok(Self {
465 dtype,
466 left_parts,
467 left_parts_dictionary,
468 right_parts,
469 right_bit_width,
470 left_parts_patches,
471 stats_set: Default::default(),
472 })
473 }
474
475 pub(crate) unsafe fn new_unchecked(
478 dtype: DType,
479 left_parts: ArrayRef,
480 left_parts_dictionary: Buffer<u16>,
481 right_parts: ArrayRef,
482 right_bit_width: u8,
483 left_parts_patches: Option<Patches>,
484 ) -> Self {
485 Self {
486 dtype,
487 left_parts,
488 left_parts_patches,
489 left_parts_dictionary,
490 right_parts,
491 right_bit_width,
492 stats_set: Default::default(),
493 }
494 }
495
496 pub fn into_parts(self) -> ALPRDArrayParts {
498 ALPRDArrayParts {
499 dtype: self.dtype,
500 left_parts: self.left_parts,
501 left_parts_patches: self.left_parts_patches,
502 left_parts_dictionary: self.left_parts_dictionary,
503 right_parts: self.right_parts,
504 }
505 }
506
507 #[inline]
511 pub fn is_f32(&self) -> bool {
512 matches!(&self.dtype, DType::Primitive(PType::F32, _))
513 }
514
515 pub fn left_parts(&self) -> &ArrayRef {
520 &self.left_parts
521 }
522
523 pub fn right_parts(&self) -> &ArrayRef {
525 &self.right_parts
526 }
527
528 #[inline]
529 pub fn right_bit_width(&self) -> u8 {
530 self.right_bit_width
531 }
532
533 pub fn left_parts_patches(&self) -> Option<&Patches> {
535 self.left_parts_patches.as_ref()
536 }
537
538 #[inline]
540 pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
541 &self.left_parts_dictionary
542 }
543
544 pub fn replace_left_parts_patches(&mut self, patches: Option<Patches>) {
545 self.left_parts_patches = patches;
546 }
547}
548
549impl ValidityChild<ALPRD> for ALPRD {
550 fn validity_child(array: &ALPRDArray) -> &ArrayRef {
551 array.left_parts()
552 }
553}
554
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ProstMetadata;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::PType;
    use vortex_array::patches::PatchesMetadata;
    use vortex_array::test_harness::check_metadata;

    use super::ALPRDMetadata;
    use crate::ALPRDFloat;
    use crate::alp_rd;

    /// Encodes a 1024-element float array containing three nulls with an encoder
    /// built from a single unrelated sample, then checks that decoding returns
    /// the original values and nulls.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");

        // Wrap every value in `Some`, then punch a few nulls into the fixture.
        let mut expected: Vec<Option<T>> = reals.into_iter().map(Some).collect();
        for null_idx in [1, 5, 900] {
            expected[null_idx] = None;
        }

        let real_array = PrimitiveArray::from_option_iter(expected.iter().cloned());

        // The encoder is built from a sample that differs from the fixture values.
        let encoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);
        let rd_array = encoder.encode(&real_array);

        let decoded = rd_array.to_primitive();

        assert_arrays_eq!(decoded, PrimitiveArray::from_option_iter(expected));
    }

    /// Checks the serialized form of [`ALPRDMetadata`] via `check_metadata`,
    /// using maximal field values.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_alprd_metadata() {
        let patches = PatchesMetadata::new(usize::MAX, usize::MAX, PType::U64, None, None, None);
        let metadata = ALPRDMetadata {
            right_bit_width: u32::MAX,
            dict_len: 8,
            dict: Vec::new(),
            left_parts_ptype: PType::U64 as i32,
            patches: Some(patches),
        };

        check_metadata("alprd.metadata", ProstMetadata(metadata));
    }
}