1use std::fmt::Debug;
5use std::hash::Hash;
6
7use itertools::Itertools;
8use vortex_array::Array;
9use vortex_array::ArrayBufferVisitor;
10use vortex_array::ArrayChildVisitor;
11use vortex_array::ArrayEq;
12use vortex_array::ArrayHash;
13use vortex_array::ArrayRef;
14use vortex_array::DeserializeMetadata;
15use vortex_array::ExecutionCtx;
16use vortex_array::IntoArray;
17use vortex_array::Precision;
18use vortex_array::ProstMetadata;
19use vortex_array::SerializeMetadata;
20use vortex_array::arrays::PrimitiveArray;
21use vortex_array::buffer::BufferHandle;
22use vortex_array::patches::Patches;
23use vortex_array::patches::PatchesMetadata;
24use vortex_array::serde::ArrayChildren;
25use vortex_array::stats::ArrayStats;
26use vortex_array::stats::StatsSetRef;
27use vortex_array::validity::Validity;
28use vortex_array::vtable;
29use vortex_array::vtable::ArrayId;
30use vortex_array::vtable::BaseArrayVTable;
31use vortex_array::vtable::VTable;
32use vortex_array::vtable::ValidityChild;
33use vortex_array::vtable::ValidityVTableFromChild;
34use vortex_array::vtable::VisitorVTable;
35use vortex_buffer::Buffer;
36use vortex_dtype::DType;
37use vortex_dtype::Nullability;
38use vortex_dtype::PType;
39use vortex_error::VortexExpect;
40use vortex_error::VortexResult;
41use vortex_error::vortex_bail;
42use vortex_error::vortex_ensure;
43use vortex_error::vortex_err;
44use vortex_mask::Mask;
45use vortex_session::VortexSession;
46
47use crate::alp_rd::kernel::PARENT_KERNELS;
48use crate::alp_rd::rules::RULES;
49use crate::alp_rd_decode;
50
// Invokes the `vortex_array::vtable!` macro for the `ALPRD` encoding.
// NOTE(review): presumably this generates the glue tying `ALPRDArray` and `ALPRDVTable`
// into the array framework — confirm against the macro definition.
vtable!(ALPRD);
52
/// Protobuf-encoded metadata stored alongside the children of an ALP-RD array.
#[derive(Clone, prost::Message)]
pub struct ALPRDMetadata {
    /// Bit width of the right parts (a `u8` on the array, widened to `u32` here).
    #[prost(uint32, tag = "1")]
    right_bit_width: u32,
    /// Number of leading entries of `dict` that are in use.
    #[prost(uint32, tag = "2")]
    dict_len: u32,
    /// Left-parts dictionary codes; each one must fit in a `u16` when the array is rebuilt.
    #[prost(uint32, repeated, tag = "3")]
    dict: Vec<u32>,
    /// `PType` of the `left_parts` child, stored as its prost enumeration value.
    #[prost(enumeration = "PType", tag = "4")]
    left_parts_ptype: i32,
    /// Metadata describing the left-parts patches, if the array has any.
    #[prost(message, tag = "5")]
    patches: Option<PatchesMetadata>,
}
66
67impl VTable for ALPRDVTable {
68 type Array = ALPRDArray;
69
70 type Metadata = ProstMetadata<ALPRDMetadata>;
71
72 type ArrayVTable = Self;
73 type OperationsVTable = Self;
74 type ValidityVTable = ValidityVTableFromChild;
75 type VisitorVTable = Self;
76
77 fn id(_array: &Self::Array) -> ArrayId {
78 Self::ID
79 }
80
81 fn metadata(array: &ALPRDArray) -> VortexResult<Self::Metadata> {
82 let dict = array
83 .left_parts_dictionary()
84 .iter()
85 .map(|&i| i as u32)
86 .collect::<Vec<_>>();
87
88 Ok(ProstMetadata(ALPRDMetadata {
89 right_bit_width: array.right_bit_width() as u32,
90 dict_len: array.left_parts_dictionary().len() as u32,
91 dict,
92 left_parts_ptype: PType::try_from(array.left_parts().dtype())
93 .vortex_expect("Must be a valid PType") as i32,
94 patches: array
95 .left_parts_patches()
96 .map(|p| p.to_metadata(array.len(), array.left_parts().dtype()))
97 .transpose()?,
98 }))
99 }
100
101 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
102 Ok(Some(metadata.serialize()))
103 }
104
105 fn deserialize(
106 bytes: &[u8],
107 _dtype: &DType,
108 _len: usize,
109 _buffers: &[BufferHandle],
110 _session: &VortexSession,
111 ) -> VortexResult<Self::Metadata> {
112 Ok(ProstMetadata(
113 <ProstMetadata<ALPRDMetadata> as DeserializeMetadata>::deserialize(bytes)?,
114 ))
115 }
116
117 fn build(
118 dtype: &DType,
119 len: usize,
120 metadata: &Self::Metadata,
121 _buffers: &[BufferHandle],
122 children: &dyn ArrayChildren,
123 ) -> VortexResult<ALPRDArray> {
124 if children.len() < 2 {
125 vortex_bail!(
126 "Expected at least 2 children for ALPRD encoding, found {}",
127 children.len()
128 );
129 }
130
131 let left_parts_dtype = DType::Primitive(metadata.0.left_parts_ptype(), dtype.nullability());
132 let left_parts = children.get(0, &left_parts_dtype, len)?;
133 let left_parts_dictionary: Buffer<u16> = metadata.0.dict.as_slice()
134 [0..metadata.0.dict_len as usize]
135 .iter()
136 .map(|&i| {
137 u16::try_from(i)
138 .map_err(|_| vortex_err!("left_parts_dictionary code {i} does not fit in u16"))
139 })
140 .try_collect()?;
141
142 let right_parts_dtype = match &dtype {
143 DType::Primitive(PType::F32, _) => {
144 DType::Primitive(PType::U32, Nullability::NonNullable)
145 }
146 DType::Primitive(PType::F64, _) => {
147 DType::Primitive(PType::U64, Nullability::NonNullable)
148 }
149 _ => vortex_bail!("Expected f32 or f64 dtype, got {:?}", dtype),
150 };
151 let right_parts = children.get(1, &right_parts_dtype, len)?;
152
153 let left_parts_patches = metadata
154 .0
155 .patches
156 .map(|p| {
157 let indices = children.get(2, &p.indices_dtype()?, p.len()?)?;
158 let values = children.get(3, &left_parts_dtype, p.len()?)?;
159
160 Patches::new(
161 len,
162 p.offset()?,
163 indices,
164 values,
165 None,
167 )
168 })
169 .transpose()?;
170
171 ALPRDArray::try_new(
172 dtype.clone(),
173 left_parts,
174 left_parts_dictionary,
175 right_parts,
176 u8::try_from(metadata.0.right_bit_width).map_err(|_| {
177 vortex_err!(
178 "right_bit_width {} out of u8 range",
179 metadata.0.right_bit_width
180 )
181 })?,
182 left_parts_patches,
183 )
184 }
185
186 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
187 let patches_info = array
189 .left_parts_patches
190 .as_ref()
191 .map(|p| (p.array_len(), p.offset()));
192
193 let expected_children = if patches_info.is_some() { 4 } else { 2 };
194
195 vortex_ensure!(
196 children.len() == expected_children,
197 "ALPRDArray expects {} children, got {}",
198 expected_children,
199 children.len()
200 );
201
202 let mut children_iter = children.into_iter();
203 array.left_parts = children_iter
204 .next()
205 .ok_or_else(|| vortex_err!("Expected left_parts child"))?;
206 array.right_parts = children_iter
207 .next()
208 .ok_or_else(|| vortex_err!("Expected right_parts child"))?;
209
210 if let Some((array_len, offset)) = patches_info {
211 let indices = children_iter
212 .next()
213 .ok_or_else(|| vortex_err!("Expected patch indices child"))?;
214 let values = children_iter
215 .next()
216 .ok_or_else(|| vortex_err!("Expected patch values child"))?;
217
218 array.left_parts_patches = Some(Patches::new(
219 array_len, offset, indices, values,
220 None, )?);
222 }
223
224 Ok(())
225 }
226
227 fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
228 let left_parts = array.left_parts().clone().execute::<PrimitiveArray>(ctx)?;
229 let right_parts = array.right_parts().clone().execute::<PrimitiveArray>(ctx)?;
230
231 let left_parts_dict = array.left_parts_dictionary();
233
234 let validity = array
235 .left_parts()
236 .validity()?
237 .to_array(array.len())
238 .execute::<Mask>(ctx)?;
239
240 let decoded_array = if array.is_f32() {
241 PrimitiveArray::new(
242 alp_rd_decode::<f32>(
243 left_parts.into_buffer::<u16>(),
244 left_parts_dict,
245 array.right_bit_width,
246 right_parts.into_buffer_mut::<u32>(),
247 array.left_parts_patches(),
248 ctx,
249 )?,
250 Validity::from_mask(validity, array.dtype().nullability()),
251 )
252 } else {
253 PrimitiveArray::new(
254 alp_rd_decode::<f64>(
255 left_parts.into_buffer::<u16>(),
256 left_parts_dict,
257 array.right_bit_width,
258 right_parts.into_buffer_mut::<u64>(),
259 array.left_parts_patches(),
260 ctx,
261 )?,
262 Validity::from_mask(validity, array.dtype().nullability()),
263 )
264 };
265
266 Ok(decoded_array.into_array())
267 }
268
269 fn reduce_parent(
270 array: &Self::Array,
271 parent: &ArrayRef,
272 child_idx: usize,
273 ) -> VortexResult<Option<ArrayRef>> {
274 RULES.evaluate(array, parent, child_idx)
275 }
276
277 fn execute_parent(
278 array: &Self::Array,
279 parent: &ArrayRef,
280 child_idx: usize,
281 ctx: &mut ExecutionCtx,
282 ) -> VortexResult<Option<ArrayRef>> {
283 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
284 }
285}
286
/// An array of `f32`/`f64` values stored as separate left and right parts.
#[derive(Clone, Debug)]
pub struct ALPRDArray {
    /// Logical (floating point) dtype of the decoded values.
    dtype: DType,
    /// Unsigned-integer left parts; this child also supplies the array's validity.
    left_parts: ArrayRef,
    /// Optional patches applied to the left parts when decoding.
    left_parts_patches: Option<Patches>,
    /// Dictionary used to decode the left parts.
    left_parts_dictionary: Buffer<u16>,
    /// Non-nullable unsigned-integer right parts.
    right_parts: ArrayRef,
    /// Bit width of the right parts.
    right_bit_width: u8,
    /// Statistics for this array.
    stats_set: ArrayStats,
}
297
/// Marker type on which the vtable traits for [`ALPRDArray`] are implemented.
#[derive(Debug)]
pub struct ALPRDVTable;
300
impl ALPRDVTable {
    /// Identifier of the ALP-RD encoding.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.alprd");
}
304
305impl ALPRDArray {
306 pub fn try_new(
308 dtype: DType,
309 left_parts: ArrayRef,
310 left_parts_dictionary: Buffer<u16>,
311 right_parts: ArrayRef,
312 right_bit_width: u8,
313 left_parts_patches: Option<Patches>,
314 ) -> VortexResult<Self> {
315 if !dtype.is_float() {
316 vortex_bail!("ALPRDArray given invalid DType ({dtype})");
317 }
318
319 let len = left_parts.len();
320 if right_parts.len() != len {
321 vortex_bail!(
322 "left_parts (len {}) and right_parts (len {}) must be of same length",
323 len,
324 right_parts.len()
325 );
326 }
327
328 if !left_parts.dtype().is_unsigned_int() {
329 vortex_bail!("left_parts dtype must be uint");
330 }
331 if dtype.is_nullable() != left_parts.dtype().is_nullable() {
333 vortex_bail!(
334 "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
335 dtype,
336 left_parts.dtype()
337 );
338 }
339
340 if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
342 vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
343 }
344
345 let left_parts_patches = left_parts_patches
346 .map(|patches| {
347 if !patches.values().all_valid()? {
348 vortex_bail!("patches must be all valid: {}", patches.values());
349 }
350 let mut patches = patches.cast_values(left_parts.dtype())?;
353 *patches.values_mut() = patches.values().to_canonical()?.into_array();
356 Ok(patches)
357 })
358 .transpose()?;
359
360 Ok(Self {
361 dtype,
362 left_parts,
363 left_parts_dictionary,
364 right_parts,
365 right_bit_width,
366 left_parts_patches,
367 stats_set: Default::default(),
368 })
369 }
370
371 pub(crate) unsafe fn new_unchecked(
374 dtype: DType,
375 left_parts: ArrayRef,
376 left_parts_dictionary: Buffer<u16>,
377 right_parts: ArrayRef,
378 right_bit_width: u8,
379 left_parts_patches: Option<Patches>,
380 ) -> Self {
381 Self {
382 dtype,
383 left_parts,
384 left_parts_patches,
385 left_parts_dictionary,
386 right_parts,
387 right_bit_width,
388 stats_set: Default::default(),
389 }
390 }
391
392 #[inline]
396 pub fn is_f32(&self) -> bool {
397 matches!(&self.dtype, DType::Primitive(PType::F32, _))
398 }
399
400 pub fn left_parts(&self) -> &ArrayRef {
405 &self.left_parts
406 }
407
408 pub fn right_parts(&self) -> &ArrayRef {
410 &self.right_parts
411 }
412
413 #[inline]
414 pub fn right_bit_width(&self) -> u8 {
415 self.right_bit_width
416 }
417
418 pub fn left_parts_patches(&self) -> Option<&Patches> {
420 self.left_parts_patches.as_ref()
421 }
422
423 #[inline]
425 pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
426 &self.left_parts_dictionary
427 }
428
429 pub fn replace_left_parts_patches(&mut self, patches: Option<Patches>) {
430 self.left_parts_patches = patches;
431 }
432}
433
434impl ValidityChild<ALPRDVTable> for ALPRDVTable {
435 fn validity_child(array: &ALPRDArray) -> &ArrayRef {
436 array.left_parts()
437 }
438}
439
440impl BaseArrayVTable<ALPRDVTable> for ALPRDVTable {
441 fn len(array: &ALPRDArray) -> usize {
442 array.left_parts.len()
443 }
444
445 fn dtype(array: &ALPRDArray) -> &DType {
446 &array.dtype
447 }
448
449 fn stats(array: &ALPRDArray) -> StatsSetRef<'_> {
450 array.stats_set.to_ref(array.as_ref())
451 }
452
453 fn array_hash<H: std::hash::Hasher>(array: &ALPRDArray, state: &mut H, precision: Precision) {
454 array.dtype.hash(state);
455 array.left_parts.array_hash(state, precision);
456 array.left_parts_dictionary.array_hash(state, precision);
457 array.right_parts.array_hash(state, precision);
458 array.right_bit_width.hash(state);
459 array.left_parts_patches.array_hash(state, precision);
460 }
461
462 fn array_eq(array: &ALPRDArray, other: &ALPRDArray, precision: Precision) -> bool {
463 array.dtype == other.dtype
464 && array.left_parts.array_eq(&other.left_parts, precision)
465 && array
466 .left_parts_dictionary
467 .array_eq(&other.left_parts_dictionary, precision)
468 && array.right_parts.array_eq(&other.right_parts, precision)
469 && array.right_bit_width == other.right_bit_width
470 && array
471 .left_parts_patches
472 .array_eq(&other.left_parts_patches, precision)
473 }
474}
475
476impl VisitorVTable<ALPRDVTable> for ALPRDVTable {
477 fn visit_buffers(_array: &ALPRDArray, _visitor: &mut dyn ArrayBufferVisitor) {}
478
479 fn nbuffers(_array: &ALPRDArray) -> usize {
480 0
481 }
482
483 fn visit_children(array: &ALPRDArray, visitor: &mut dyn ArrayChildVisitor) {
484 visitor.visit_child("left_parts", array.left_parts());
485 visitor.visit_child("right_parts", array.right_parts());
486 if let Some(patches) = array.left_parts_patches() {
487 visitor.visit_patches(patches);
488 }
489 }
490
491 fn nchildren(array: &ALPRDArray) -> usize {
492 2 + array
494 .left_parts_patches()
495 .map_or(0, |p| 2 + p.chunk_offsets().is_some() as usize)
496 }
497}
498
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ProstMetadata;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::patches::PatchesMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use super::ALPRDMetadata;
    use crate::ALPRDFloat;
    use crate::alp_rd;

    /// Encodes a nullable float array with an `RDEncoder` and checks that decoding it
    /// yields the original values, nulls included.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");

        // Wrap every value in `Some`, then null out a handful of positions.
        let mut expected: Vec<Option<T>> = reals.into_iter().map(Some).collect();
        for null_idx in [1, 5, 900] {
            expected[null_idx] = None;
        }

        let input = PrimitiveArray::from_option_iter(expected.iter().cloned());

        // The encoder is built from a single sample derived from `seed`.
        // NOTE(review): presumably chosen so that the fixture values need patches — confirm.
        let encoder: alp_rd::RDEncoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);

        let encoded = encoder.encode(&input);
        let decoded = encoded.to_primitive();

        assert_arrays_eq!(decoded, PrimitiveArray::from_option_iter(expected));
    }

    /// Checks the serialized form of the metadata using extreme field values.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_alprd_metadata() {
        let patches = PatchesMetadata::new(usize::MAX, usize::MAX, PType::U64, None, None, None);
        let metadata = ALPRDMetadata {
            right_bit_width: u32::MAX,
            patches: Some(patches),
            dict: Vec::new(),
            left_parts_ptype: PType::U64 as i32,
            dict_len: 8,
        };

        check_metadata("alprd.metadata", ProstMetadata(metadata));
    }
}