1use std::fmt::Debug;
5use std::hash::Hash;
6
7use itertools::Itertools;
8use vortex_array::ArrayEq;
9use vortex_array::ArrayHash;
10use vortex_array::ArrayRef;
11use vortex_array::DeserializeMetadata;
12use vortex_array::DynArray;
13use vortex_array::ExecutionCtx;
14use vortex_array::ExecutionStep;
15use vortex_array::IntoArray;
16use vortex_array::Precision;
17use vortex_array::ProstMetadata;
18use vortex_array::SerializeMetadata;
19use vortex_array::arrays::PrimitiveArray;
20use vortex_array::buffer::BufferHandle;
21use vortex_array::dtype::DType;
22use vortex_array::dtype::Nullability;
23use vortex_array::dtype::PType;
24use vortex_array::patches::Patches;
25use vortex_array::patches::PatchesMetadata;
26use vortex_array::serde::ArrayChildren;
27use vortex_array::stats::ArrayStats;
28use vortex_array::stats::StatsSetRef;
29use vortex_array::validity::Validity;
30use vortex_array::vtable;
31use vortex_array::vtable::ArrayId;
32use vortex_array::vtable::VTable;
33use vortex_array::vtable::ValidityChild;
34use vortex_array::vtable::ValidityVTableFromChild;
35use vortex_array::vtable::patches_child;
36use vortex_array::vtable::patches_child_name;
37use vortex_array::vtable::patches_nchildren;
38use vortex_buffer::Buffer;
39use vortex_error::VortexResult;
40use vortex_error::vortex_bail;
41use vortex_error::vortex_ensure;
42use vortex_error::vortex_err;
43use vortex_error::vortex_panic;
44use vortex_mask::Mask;
45use vortex_session::VortexSession;
46
47use crate::alp_rd::kernel::PARENT_KERNELS;
48use crate::alp_rd::rules::RULES;
49use crate::alp_rd_decode;
50
// Macro-generated glue for the ALPRD encoding.
// NOTE(review): presumably this wires `ALPRDArray` into the generic array machinery
// (e.g. the `as_ref()` / `len()` helpers used below) — confirm against the `vtable!`
// macro definition in `vortex_array`.
vtable!(ALPRD);
52
/// Serialized (prost) metadata for an [`ALPRDArray`].
///
/// Produced by `ALPRDVTable::metadata` and consumed by `ALPRDVTable::build`.
#[derive(Clone, prost::Message)]
pub struct ALPRDMetadata {
    /// The array's `u8` right bit width, widened to `u32` for serialization.
    /// `build` rejects values that do not fit back into a `u8`.
    #[prost(uint32, tag = "1")]
    right_bit_width: u32,
    /// Number of leading entries of `dict` that form the left-parts dictionary.
    #[prost(uint32, tag = "2")]
    dict_len: u32,
    /// Left-parts dictionary entries; each is a `u16` widened to `u32`.
    #[prost(uint32, repeated, tag = "3")]
    dict: Vec<u32>,
    /// `PType` of the `left_parts` child, stored as its prost enumeration value.
    #[prost(enumeration = "PType", tag = "4")]
    left_parts_ptype: i32,
    /// Metadata of the optional left-parts patches; `None` when the array has no patches.
    #[prost(message, tag = "5")]
    patches: Option<PatchesMetadata>,
}
66
67impl VTable for ALPRDVTable {
68 type Array = ALPRDArray;
69
70 type Metadata = ProstMetadata<ALPRDMetadata>;
71 type OperationsVTable = Self;
72 type ValidityVTable = ValidityVTableFromChild;
73
74 fn id(_array: &Self::Array) -> ArrayId {
75 Self::ID
76 }
77
78 fn len(array: &ALPRDArray) -> usize {
79 array.left_parts.len()
80 }
81
82 fn dtype(array: &ALPRDArray) -> &DType {
83 &array.dtype
84 }
85
86 fn stats(array: &ALPRDArray) -> StatsSetRef<'_> {
87 array.stats_set.to_ref(array.as_ref())
88 }
89
90 fn array_hash<H: std::hash::Hasher>(array: &ALPRDArray, state: &mut H, precision: Precision) {
91 array.dtype.hash(state);
92 array.left_parts.array_hash(state, precision);
93 array.left_parts_dictionary.array_hash(state, precision);
94 array.right_parts.array_hash(state, precision);
95 array.right_bit_width.hash(state);
96 array.left_parts_patches.array_hash(state, precision);
97 }
98
99 fn array_eq(array: &ALPRDArray, other: &ALPRDArray, precision: Precision) -> bool {
100 array.dtype == other.dtype
101 && array.left_parts.array_eq(&other.left_parts, precision)
102 && array
103 .left_parts_dictionary
104 .array_eq(&other.left_parts_dictionary, precision)
105 && array.right_parts.array_eq(&other.right_parts, precision)
106 && array.right_bit_width == other.right_bit_width
107 && array
108 .left_parts_patches
109 .array_eq(&other.left_parts_patches, precision)
110 }
111
112 fn nbuffers(_array: &ALPRDArray) -> usize {
113 0
114 }
115
116 fn buffer(_array: &ALPRDArray, idx: usize) -> BufferHandle {
117 vortex_panic!("ALPRDArray buffer index {idx} out of bounds")
118 }
119
120 fn buffer_name(_array: &ALPRDArray, _idx: usize) -> Option<String> {
121 None
122 }
123
124 fn nchildren(array: &ALPRDArray) -> usize {
125 2 + array.left_parts_patches().map_or(0, patches_nchildren)
126 }
127
128 fn child(array: &ALPRDArray, idx: usize) -> ArrayRef {
129 match idx {
130 0 => array.left_parts().clone(),
131 1 => array.right_parts().clone(),
132 _ => {
133 let patches = array
134 .left_parts_patches()
135 .unwrap_or_else(|| vortex_panic!("ALPRDArray child index {idx} out of bounds"));
136 patches_child(patches, idx - 2)
137 }
138 }
139 }
140
141 fn child_name(array: &ALPRDArray, idx: usize) -> String {
142 match idx {
143 0 => "left_parts".to_string(),
144 1 => "right_parts".to_string(),
145 _ => {
146 if array.left_parts_patches().is_none() {
147 vortex_panic!("ALPRDArray child_name index {idx} out of bounds");
148 }
149 patches_child_name(idx - 2).to_string()
150 }
151 }
152 }
153
154 fn metadata(array: &ALPRDArray) -> VortexResult<Self::Metadata> {
155 let dict = array
156 .left_parts_dictionary()
157 .iter()
158 .map(|&i| i as u32)
159 .collect::<Vec<_>>();
160
161 Ok(ProstMetadata(ALPRDMetadata {
162 right_bit_width: array.right_bit_width() as u32,
163 dict_len: array.left_parts_dictionary().len() as u32,
164 dict,
165 left_parts_ptype: array.left_parts.dtype().as_ptype() as i32,
166 patches: array
167 .left_parts_patches()
168 .map(|p| p.to_metadata(array.len(), array.left_parts().dtype()))
169 .transpose()?,
170 }))
171 }
172
173 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
174 Ok(Some(metadata.serialize()))
175 }
176
177 fn deserialize(
178 bytes: &[u8],
179 _dtype: &DType,
180 _len: usize,
181 _buffers: &[BufferHandle],
182 _session: &VortexSession,
183 ) -> VortexResult<Self::Metadata> {
184 Ok(ProstMetadata(
185 <ProstMetadata<ALPRDMetadata> as DeserializeMetadata>::deserialize(bytes)?,
186 ))
187 }
188
189 fn build(
190 dtype: &DType,
191 len: usize,
192 metadata: &Self::Metadata,
193 _buffers: &[BufferHandle],
194 children: &dyn ArrayChildren,
195 ) -> VortexResult<ALPRDArray> {
196 if children.len() < 2 {
197 vortex_bail!(
198 "Expected at least 2 children for ALPRD encoding, found {}",
199 children.len()
200 );
201 }
202
203 let left_parts_dtype = DType::Primitive(metadata.0.left_parts_ptype(), dtype.nullability());
204 let left_parts = children.get(0, &left_parts_dtype, len)?;
205 let left_parts_dictionary: Buffer<u16> = metadata.0.dict.as_slice()
206 [0..metadata.0.dict_len as usize]
207 .iter()
208 .map(|&i| {
209 u16::try_from(i)
210 .map_err(|_| vortex_err!("left_parts_dictionary code {i} does not fit in u16"))
211 })
212 .try_collect()?;
213
214 let right_parts_dtype = match &dtype {
215 DType::Primitive(PType::F32, _) => {
216 DType::Primitive(PType::U32, Nullability::NonNullable)
217 }
218 DType::Primitive(PType::F64, _) => {
219 DType::Primitive(PType::U64, Nullability::NonNullable)
220 }
221 _ => vortex_bail!("Expected f32 or f64 dtype, got {:?}", dtype),
222 };
223 let right_parts = children.get(1, &right_parts_dtype, len)?;
224
225 let left_parts_patches = metadata
226 .0
227 .patches
228 .map(|p| {
229 let indices = children.get(2, &p.indices_dtype()?, p.len()?)?;
230 let values = children.get(3, &left_parts_dtype, p.len()?)?;
231
232 Patches::new(
233 len,
234 p.offset()?,
235 indices,
236 values,
237 None,
239 )
240 })
241 .transpose()?;
242
243 ALPRDArray::try_new(
244 dtype.clone(),
245 left_parts,
246 left_parts_dictionary,
247 right_parts,
248 u8::try_from(metadata.0.right_bit_width).map_err(|_| {
249 vortex_err!(
250 "right_bit_width {} out of u8 range",
251 metadata.0.right_bit_width
252 )
253 })?,
254 left_parts_patches,
255 )
256 }
257
258 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
259 let patches_info = array
261 .left_parts_patches
262 .as_ref()
263 .map(|p| (p.array_len(), p.offset()));
264
265 let expected_children = if patches_info.is_some() { 4 } else { 2 };
266
267 vortex_ensure!(
268 children.len() == expected_children,
269 "ALPRDArray expects {} children, got {}",
270 expected_children,
271 children.len()
272 );
273
274 let mut children_iter = children.into_iter();
275 array.left_parts = children_iter
276 .next()
277 .ok_or_else(|| vortex_err!("Expected left_parts child"))?;
278 array.right_parts = children_iter
279 .next()
280 .ok_or_else(|| vortex_err!("Expected right_parts child"))?;
281
282 if let Some((array_len, offset)) = patches_info {
283 let indices = children_iter
284 .next()
285 .ok_or_else(|| vortex_err!("Expected patch indices child"))?;
286 let values = children_iter
287 .next()
288 .ok_or_else(|| vortex_err!("Expected patch values child"))?;
289
290 array.left_parts_patches = Some(Patches::new(
291 array_len, offset, indices, values,
292 None, )?);
294 }
295
296 Ok(())
297 }
298
299 fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
300 let left_parts = array.left_parts().clone().execute::<PrimitiveArray>(ctx)?;
301 let right_parts = array.right_parts().clone().execute::<PrimitiveArray>(ctx)?;
302
303 let left_parts_dict = array.left_parts_dictionary();
305
306 let validity = array
307 .left_parts()
308 .validity()?
309 .to_array(array.len())
310 .execute::<Mask>(ctx)?;
311
312 let decoded_array = if array.is_f32() {
313 PrimitiveArray::new(
314 alp_rd_decode::<f32>(
315 left_parts.into_buffer::<u16>(),
316 left_parts_dict,
317 array.right_bit_width,
318 right_parts.into_buffer_mut::<u32>(),
319 array.left_parts_patches(),
320 ctx,
321 )?,
322 Validity::from_mask(validity, array.dtype().nullability()),
323 )
324 } else {
325 PrimitiveArray::new(
326 alp_rd_decode::<f64>(
327 left_parts.into_buffer::<u16>(),
328 left_parts_dict,
329 array.right_bit_width,
330 right_parts.into_buffer_mut::<u64>(),
331 array.left_parts_patches(),
332 ctx,
333 )?,
334 Validity::from_mask(validity, array.dtype().nullability()),
335 )
336 };
337
338 Ok(ExecutionStep::Done(decoded_array.into_array()))
339 }
340
341 fn reduce_parent(
342 array: &Self::Array,
343 parent: &ArrayRef,
344 child_idx: usize,
345 ) -> VortexResult<Option<ArrayRef>> {
346 RULES.evaluate(array, parent, child_idx)
347 }
348
349 fn execute_parent(
350 array: &Self::Array,
351 parent: &ArrayRef,
352 child_idx: usize,
353 ctx: &mut ExecutionCtx,
354 ) -> VortexResult<Option<ArrayRef>> {
355 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
356 }
357}
358
/// A floating-point array stored with the ALP-RD encoding.
///
/// Values are split into a `left_parts` child, a dictionary, optional patches and a
/// `right_parts` child, all of which are combined by `alp_rd_decode` on execution.
#[derive(Clone, Debug)]
pub struct ALPRDArray {
    /// Logical dtype; `try_new` requires a float type.
    dtype: DType,
    /// Unsigned-integer child; its validity is the validity of the whole array.
    /// NOTE(review): presumably these are codes into `left_parts_dictionary` — confirm
    /// against `alp_rd_decode`.
    left_parts: ArrayRef,
    /// Optional exceptions for the left parts; `try_new` casts their values to the
    /// `left_parts` dtype and canonicalizes them.
    left_parts_patches: Option<Patches>,
    /// Left-parts dictionary, serialized into the metadata (not a child array).
    left_parts_dictionary: Buffer<u16>,
    /// Non-nullable unsigned-integer child with the same length as `left_parts`.
    right_parts: ArrayRef,
    /// Number of bits in each right part.
    right_bit_width: u8,
    /// Statistics storage exposed through `VTable::stats`.
    stats_set: ArrayStats,
}
369
/// Marker type that implements [`VTable`] for [`ALPRDArray`].
#[derive(Debug)]
pub struct ALPRDVTable;
372
impl ALPRDVTable {
    /// Encoding identifier reported for every ALP-RD array.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.alprd");
}
376
377impl ALPRDArray {
378 pub fn try_new(
380 dtype: DType,
381 left_parts: ArrayRef,
382 left_parts_dictionary: Buffer<u16>,
383 right_parts: ArrayRef,
384 right_bit_width: u8,
385 left_parts_patches: Option<Patches>,
386 ) -> VortexResult<Self> {
387 if !dtype.is_float() {
388 vortex_bail!("ALPRDArray given invalid DType ({dtype})");
389 }
390
391 let len = left_parts.len();
392 if right_parts.len() != len {
393 vortex_bail!(
394 "left_parts (len {}) and right_parts (len {}) must be of same length",
395 len,
396 right_parts.len()
397 );
398 }
399
400 if !left_parts.dtype().is_unsigned_int() {
401 vortex_bail!("left_parts dtype must be uint");
402 }
403 if dtype.is_nullable() != left_parts.dtype().is_nullable() {
405 vortex_bail!(
406 "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
407 dtype,
408 left_parts.dtype()
409 );
410 }
411
412 if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
414 vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
415 }
416
417 let left_parts_patches = left_parts_patches
418 .map(|patches| {
419 if !patches.values().all_valid()? {
420 vortex_bail!("patches must be all valid: {}", patches.values());
421 }
422 let mut patches = patches.cast_values(left_parts.dtype())?;
425 *patches.values_mut() = patches.values().to_canonical()?.into_array();
428 Ok(patches)
429 })
430 .transpose()?;
431
432 Ok(Self {
433 dtype,
434 left_parts,
435 left_parts_dictionary,
436 right_parts,
437 right_bit_width,
438 left_parts_patches,
439 stats_set: Default::default(),
440 })
441 }
442
443 pub(crate) unsafe fn new_unchecked(
446 dtype: DType,
447 left_parts: ArrayRef,
448 left_parts_dictionary: Buffer<u16>,
449 right_parts: ArrayRef,
450 right_bit_width: u8,
451 left_parts_patches: Option<Patches>,
452 ) -> Self {
453 Self {
454 dtype,
455 left_parts,
456 left_parts_patches,
457 left_parts_dictionary,
458 right_parts,
459 right_bit_width,
460 stats_set: Default::default(),
461 }
462 }
463
464 #[inline]
468 pub fn is_f32(&self) -> bool {
469 matches!(&self.dtype, DType::Primitive(PType::F32, _))
470 }
471
472 pub fn left_parts(&self) -> &ArrayRef {
477 &self.left_parts
478 }
479
480 pub fn right_parts(&self) -> &ArrayRef {
482 &self.right_parts
483 }
484
485 #[inline]
486 pub fn right_bit_width(&self) -> u8 {
487 self.right_bit_width
488 }
489
490 pub fn left_parts_patches(&self) -> Option<&Patches> {
492 self.left_parts_patches.as_ref()
493 }
494
495 #[inline]
497 pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
498 &self.left_parts_dictionary
499 }
500
501 pub fn replace_left_parts_patches(&mut self, patches: Option<Patches>) {
502 self.left_parts_patches = patches;
503 }
504}
505
506impl ValidityChild<ALPRDVTable> for ALPRDVTable {
507 fn validity_child(array: &ALPRDArray) -> &ArrayRef {
508 array.left_parts()
509 }
510}
511
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ProstMetadata;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::PType;
    use vortex_array::patches::PatchesMetadata;
    use vortex_array::test_harness::check_metadata;

    use super::ALPRDMetadata;
    use crate::ALPRDFloat;
    use crate::alp_rd;

    /// Encodes a float fixture containing a few nulls with ALP-RD, decodes it back and
    /// checks the round trip is exact.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");

        // Replace a handful of positions with nulls.
        let mut expected: Vec<Option<T>> = reals.into_iter().map(Some).collect();
        for null_idx in [1, 5, 900] {
            expected[null_idx] = None;
        }

        let input = PrimitiveArray::from_option_iter(expected.iter().cloned());
        let encoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);
        let decoded = encoder.encode(&input).to_primitive();

        assert_arrays_eq!(decoded, PrimitiveArray::from_option_iter(expected));
    }

    /// Pins the serialized form of the metadata against the checked-in golden file.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_alprd_metadata() {
        let patches = PatchesMetadata::new(usize::MAX, usize::MAX, PType::U64, None, None, None);
        let metadata = ALPRDMetadata {
            right_bit_width: u32::MAX,
            dict_len: 8,
            dict: Vec::new(),
            left_parts_ptype: PType::U64 as i32,
            patches: Some(patches),
        };
        check_metadata("alprd.metadata", ProstMetadata(metadata));
    }
}