1use std::fmt::Debug;
5use std::hash::Hash;
6
7use itertools::Itertools;
8use vortex_array::Array;
9use vortex_array::ArrayBufferVisitor;
10use vortex_array::ArrayChildVisitor;
11use vortex_array::ArrayEq;
12use vortex_array::ArrayHash;
13use vortex_array::ArrayRef;
14use vortex_array::Canonical;
15use vortex_array::DeserializeMetadata;
16use vortex_array::Precision;
17use vortex_array::ProstMetadata;
18use vortex_array::SerializeMetadata;
19use vortex_array::ToCanonical;
20use vortex_array::arrays::PrimitiveArray;
21use vortex_array::buffer::BufferHandle;
22use vortex_array::patches::Patches;
23use vortex_array::patches::PatchesMetadata;
24use vortex_array::serde::ArrayChildren;
25use vortex_array::stats::ArrayStats;
26use vortex_array::stats::StatsSetRef;
27use vortex_array::validity::Validity;
28use vortex_array::vtable;
29use vortex_array::vtable::ArrayId;
30use vortex_array::vtable::ArrayVTable;
31use vortex_array::vtable::ArrayVTableExt;
32use vortex_array::vtable::BaseArrayVTable;
33use vortex_array::vtable::CanonicalVTable;
34use vortex_array::vtable::EncodeVTable;
35use vortex_array::vtable::NotSupported;
36use vortex_array::vtable::VTable;
37use vortex_array::vtable::ValidityChild;
38use vortex_array::vtable::ValidityVTableFromChild;
39use vortex_array::vtable::VisitorVTable;
40use vortex_buffer::Buffer;
41use vortex_dtype::DType;
42use vortex_dtype::Nullability;
43use vortex_dtype::PType;
44use vortex_error::VortexError;
45use vortex_error::VortexExpect;
46use vortex_error::VortexResult;
47use vortex_error::vortex_bail;
48use vortex_error::vortex_ensure;
49use vortex_error::vortex_err;
50
51use crate::alp_rd::alp_rd_decode;
52
// Invokes vortex_array's `vtable!` macro for the ALPRD encoding.
// NOTE(review): the items this macro generates are not visible in this file — presumably the
// glue connecting `ALPRDVTable` and `ALPRDArray`; confirm against the macro definition.
vtable!(ALPRD);
54
/// Serialized metadata for an ALP-RD array: everything `build` needs to reconstruct the array
/// apart from its children.
///
/// The protobuf tags below define the serialized format and must not be renumbered.
#[derive(Clone, prost::Message)]
pub struct ALPRDMetadata {
    /// The array's right bit width; narrowed back to `u8` (with a range check) in `build`.
    #[prost(uint32, tag = "1")]
    right_bit_width: u32,
    /// Number of leading entries of `dict` that make up the left-parts dictionary.
    #[prost(uint32, tag = "2")]
    dict_len: u32,
    /// Left-parts dictionary entries, widened from `u16` because protobuf has no 16-bit integer;
    /// `build` checks that each entry fits back into a `u16`.
    #[prost(uint32, repeated, tag = "3")]
    dict: Vec<u32>,
    /// `PType` of the `left_parts` child, stored as its protobuf enumeration value.
    #[prost(enumeration = "PType", tag = "4")]
    left_parts_ptype: i32,
    /// Metadata for the left-parts patches; `None` when the array has no patches.
    #[prost(message, tag = "5")]
    patches: Option<PatchesMetadata>,
}
68
69impl VTable for ALPRDVTable {
70 type Array = ALPRDArray;
71
72 type Metadata = ProstMetadata<ALPRDMetadata>;
73
74 type ArrayVTable = Self;
75 type CanonicalVTable = Self;
76 type OperationsVTable = Self;
77 type ValidityVTable = ValidityVTableFromChild;
78 type VisitorVTable = Self;
79 type ComputeVTable = NotSupported;
80 type EncodeVTable = Self;
81
82 fn id(&self) -> ArrayId {
83 ArrayId::new_ref("vortex.alprd")
84 }
85
86 fn encoding(_array: &Self::Array) -> ArrayVTable {
87 ALPRDVTable.as_vtable()
88 }
89
90 fn metadata(array: &ALPRDArray) -> VortexResult<Self::Metadata> {
91 let dict = array
92 .left_parts_dictionary()
93 .iter()
94 .map(|&i| i as u32)
95 .collect::<Vec<_>>();
96
97 Ok(ProstMetadata(ALPRDMetadata {
98 right_bit_width: array.right_bit_width() as u32,
99 dict_len: array.left_parts_dictionary().len() as u32,
100 dict,
101 left_parts_ptype: PType::try_from(array.left_parts().dtype())
102 .vortex_expect("Must be a valid PType") as i32,
103 patches: array
104 .left_parts_patches()
105 .map(|p| p.to_metadata(array.len(), array.left_parts().dtype()))
106 .transpose()?,
107 }))
108 }
109
110 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
111 Ok(Some(metadata.serialize()))
112 }
113
114 fn deserialize(buffer: &[u8]) -> VortexResult<Self::Metadata> {
115 Ok(ProstMetadata(
116 <ProstMetadata<ALPRDMetadata> as DeserializeMetadata>::deserialize(buffer)?,
117 ))
118 }
119
120 fn build(
121 &self,
122 dtype: &DType,
123 len: usize,
124 metadata: &Self::Metadata,
125 _buffers: &[BufferHandle],
126 children: &dyn ArrayChildren,
127 ) -> VortexResult<ALPRDArray> {
128 if children.len() < 2 {
129 vortex_bail!(
130 "Expected at least 2 children for ALPRD encoding, found {}",
131 children.len()
132 );
133 }
134
135 let left_parts_dtype = DType::Primitive(metadata.0.left_parts_ptype(), dtype.nullability());
136 let left_parts = children.get(0, &left_parts_dtype, len)?;
137 let left_parts_dictionary: Buffer<u16> = metadata.0.dict.as_slice()
138 [0..metadata.0.dict_len as usize]
139 .iter()
140 .map(|&i| {
141 u16::try_from(i)
142 .map_err(|_| vortex_err!("left_parts_dictionary code {i} does not fit in u16"))
143 })
144 .try_collect()?;
145
146 let right_parts_dtype = match &dtype {
147 DType::Primitive(PType::F32, _) => {
148 DType::Primitive(PType::U32, Nullability::NonNullable)
149 }
150 DType::Primitive(PType::F64, _) => {
151 DType::Primitive(PType::U64, Nullability::NonNullable)
152 }
153 _ => vortex_bail!("Expected f32 or f64 dtype, got {:?}", dtype),
154 };
155 let right_parts = children.get(1, &right_parts_dtype, len)?;
156
157 let left_parts_patches = metadata
158 .0
159 .patches
160 .map(|p| {
161 let indices = children.get(2, &p.indices_dtype(), p.len())?;
162 let values = children.get(3, &left_parts_dtype, p.len())?;
163
164 Ok::<_, VortexError>(Patches::new(
165 len,
166 p.offset(),
167 indices,
168 values,
169 None,
171 ))
172 })
173 .transpose()?;
174
175 ALPRDArray::try_new(
176 dtype.clone(),
177 left_parts,
178 left_parts_dictionary,
179 right_parts,
180 u8::try_from(metadata.0.right_bit_width).map_err(|_| {
181 vortex_err!(
182 "right_bit_width {} out of u8 range",
183 metadata.0.right_bit_width
184 )
185 })?,
186 left_parts_patches,
187 )
188 }
189
190 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
191 let patches_info = array
193 .left_parts_patches
194 .as_ref()
195 .map(|p| (p.array_len(), p.offset()));
196
197 let expected_children = if patches_info.is_some() { 4 } else { 2 };
198
199 vortex_ensure!(
200 children.len() == expected_children,
201 "ALPRDArray expects {} children, got {}",
202 expected_children,
203 children.len()
204 );
205
206 let mut children_iter = children.into_iter();
207 array.left_parts = children_iter
208 .next()
209 .ok_or_else(|| vortex_err!("Expected left_parts child"))?;
210 array.right_parts = children_iter
211 .next()
212 .ok_or_else(|| vortex_err!("Expected right_parts child"))?;
213
214 if let Some((array_len, offset)) = patches_info {
215 let indices = children_iter
216 .next()
217 .ok_or_else(|| vortex_err!("Expected patch indices child"))?;
218 let values = children_iter
219 .next()
220 .ok_or_else(|| vortex_err!("Expected patch values child"))?;
221
222 array.left_parts_patches = Some(Patches::new(
223 array_len, offset, indices, values,
224 None, ));
226 }
227
228 Ok(())
229 }
230}
231
/// A floating-point array stored with the ALP-RD encoding: a `left_parts` child paired with a
/// dictionary and optional patches, plus a `right_parts` child.
#[derive(Clone, Debug)]
pub struct ALPRDArray {
    /// Logical dtype of the array (a float type).
    dtype: DType,
    /// Unsigned-integer child whose nullability matches `dtype`; also serves as the validity
    /// child of the array.
    left_parts: ArrayRef,
    /// Optional patches for the left parts; `try_new` requires them to be all-valid and casts
    /// their values to the dtype of `left_parts`.
    left_parts_patches: Option<Patches>,
    /// Dictionary passed to `alp_rd_decode` together with the left parts.
    left_parts_dictionary: Buffer<u16>,
    /// Non-nullable unsigned-integer child with the same length as `left_parts`.
    right_parts: ArrayRef,
    /// Bit width of the right parts, forwarded to `alp_rd_decode`.
    right_bit_width: u8,
    /// Cached statistics for this array.
    stats_set: ArrayStats,
}
242
/// VTable type for [`ALPRDArray`]; its encoding id is `"vortex.alprd"`.
#[derive(Debug)]
pub struct ALPRDVTable;
245
246impl ALPRDArray {
247 pub fn try_new(
249 dtype: DType,
250 left_parts: ArrayRef,
251 left_parts_dictionary: Buffer<u16>,
252 right_parts: ArrayRef,
253 right_bit_width: u8,
254 left_parts_patches: Option<Patches>,
255 ) -> VortexResult<Self> {
256 if !dtype.is_float() {
257 vortex_bail!("ALPRDArray given invalid DType ({dtype})");
258 }
259
260 let len = left_parts.len();
261 if right_parts.len() != len {
262 vortex_bail!(
263 "left_parts (len {}) and right_parts (len {}) must be of same length",
264 len,
265 right_parts.len()
266 );
267 }
268
269 if !left_parts.dtype().is_unsigned_int() {
270 vortex_bail!("left_parts dtype must be uint");
271 }
272 if dtype.is_nullable() != left_parts.dtype().is_nullable() {
274 vortex_bail!(
275 "ALPRDArray dtype nullability ({}) must match left_parts dtype nullability ({})",
276 dtype,
277 left_parts.dtype()
278 );
279 }
280
281 if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
283 vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
284 }
285
286 let left_parts_patches = left_parts_patches
287 .map(|patches| {
288 if !patches.values().all_valid() {
289 vortex_bail!("patches must be all valid: {}", patches.values());
290 }
291 patches.cast_values(left_parts.dtype())
293 })
294 .transpose()?;
295
296 Ok(Self {
297 dtype,
298 left_parts,
299 left_parts_dictionary,
300 right_parts,
301 right_bit_width,
302 left_parts_patches,
303 stats_set: Default::default(),
304 })
305 }
306
307 pub(crate) unsafe fn new_unchecked(
310 dtype: DType,
311 left_parts: ArrayRef,
312 left_parts_dictionary: Buffer<u16>,
313 right_parts: ArrayRef,
314 right_bit_width: u8,
315 left_parts_patches: Option<Patches>,
316 ) -> Self {
317 Self {
318 dtype,
319 left_parts,
320 left_parts_patches,
321 left_parts_dictionary,
322 right_parts,
323 right_bit_width,
324 stats_set: Default::default(),
325 }
326 }
327
328 #[inline]
332 pub fn is_f32(&self) -> bool {
333 matches!(&self.dtype, DType::Primitive(PType::F32, _))
334 }
335
336 pub fn left_parts(&self) -> &ArrayRef {
341 &self.left_parts
342 }
343
344 pub fn right_parts(&self) -> &ArrayRef {
346 &self.right_parts
347 }
348
349 #[inline]
350 pub fn right_bit_width(&self) -> u8 {
351 self.right_bit_width
352 }
353
354 pub fn left_parts_patches(&self) -> Option<&Patches> {
356 self.left_parts_patches.as_ref()
357 }
358
359 #[inline]
361 pub fn left_parts_dictionary(&self) -> &Buffer<u16> {
362 &self.left_parts_dictionary
363 }
364
365 pub fn replace_left_parts_patches(&mut self, patches: Option<Patches>) {
366 self.left_parts_patches = patches;
367 }
368}
369
370impl ValidityChild<ALPRDVTable> for ALPRDVTable {
371 fn validity_child(array: &ALPRDArray) -> &ArrayRef {
372 array.left_parts()
373 }
374}
375
376impl BaseArrayVTable<ALPRDVTable> for ALPRDVTable {
377 fn len(array: &ALPRDArray) -> usize {
378 array.left_parts.len()
379 }
380
381 fn dtype(array: &ALPRDArray) -> &DType {
382 &array.dtype
383 }
384
385 fn stats(array: &ALPRDArray) -> StatsSetRef<'_> {
386 array.stats_set.to_ref(array.as_ref())
387 }
388
389 fn array_hash<H: std::hash::Hasher>(array: &ALPRDArray, state: &mut H, precision: Precision) {
390 array.dtype.hash(state);
391 array.left_parts.array_hash(state, precision);
392 array.left_parts_dictionary.array_hash(state, precision);
393 array.right_parts.array_hash(state, precision);
394 array.right_bit_width.hash(state);
395 array.left_parts_patches.array_hash(state, precision);
396 }
397
398 fn array_eq(array: &ALPRDArray, other: &ALPRDArray, precision: Precision) -> bool {
399 array.dtype == other.dtype
400 && array.left_parts.array_eq(&other.left_parts, precision)
401 && array
402 .left_parts_dictionary
403 .array_eq(&other.left_parts_dictionary, precision)
404 && array.right_parts.array_eq(&other.right_parts, precision)
405 && array.right_bit_width == other.right_bit_width
406 && array
407 .left_parts_patches
408 .array_eq(&other.left_parts_patches, precision)
409 }
410}
411
412impl CanonicalVTable<ALPRDVTable> for ALPRDVTable {
413 fn canonicalize(array: &ALPRDArray) -> Canonical {
414 let left_parts = array.left_parts().to_primitive();
415 let right_parts = array.right_parts().to_primitive();
416
417 let left_parts_dict = array.left_parts_dictionary();
419
420 let decoded_array = if array.is_f32() {
421 PrimitiveArray::new(
422 alp_rd_decode::<f32>(
423 left_parts.into_buffer::<u16>(),
424 left_parts_dict,
425 array.right_bit_width,
426 right_parts.into_buffer_mut::<u32>(),
427 array.left_parts_patches(),
428 ),
429 Validity::copy_from_array(array.as_ref()),
430 )
431 } else {
432 PrimitiveArray::new(
433 alp_rd_decode::<f64>(
434 left_parts.into_buffer::<u16>(),
435 left_parts_dict,
436 array.right_bit_width,
437 right_parts.into_buffer_mut::<u64>(),
438 array.left_parts_patches(),
439 ),
440 Validity::copy_from_array(array.as_ref()),
441 )
442 };
443
444 Canonical::Primitive(decoded_array)
445 }
446}
447
448impl EncodeVTable<ALPRDVTable> for ALPRDVTable {
449 fn encode(
450 _vtable: &ALPRDVTable,
451 canonical: &Canonical,
452 like: Option<&ALPRDArray>,
453 ) -> VortexResult<Option<ALPRDArray>> {
454 let parray = canonical.clone().into_primitive();
455
456 let alprd_array = match like {
457 None => {
458 let encoder = match parray.ptype() {
459 PType::F32 => crate::alp_rd::RDEncoder::new(parray.as_slice::<f32>()),
460 PType::F64 => crate::alp_rd::RDEncoder::new(parray.as_slice::<f64>()),
461 ptype => vortex_bail!("cannot ALPRD compress ptype {ptype}"),
462 };
463 encoder.encode(&parray)
464 }
465 Some(like) => {
466 let encoder = crate::alp_rd::RDEncoder::from_parts(
467 like.right_bit_width(),
468 like.left_parts_dictionary().to_vec(),
469 );
470 encoder.encode(&parray)
471 }
472 };
473
474 Ok(Some(alprd_array))
475 }
476}
477
478impl VisitorVTable<ALPRDVTable> for ALPRDVTable {
479 fn visit_buffers(_array: &ALPRDArray, _visitor: &mut dyn ArrayBufferVisitor) {}
480
481 fn visit_children(array: &ALPRDArray, visitor: &mut dyn ArrayChildVisitor) {
482 visitor.visit_child("left_parts", array.left_parts());
483 visitor.visit_child("right_parts", array.right_parts());
484 if let Some(patches) = array.left_parts_patches() {
485 visitor.visit_patches(patches);
486 }
487 }
488}
489
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ProstMetadata;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::patches::PatchesMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use super::ALPRDMetadata;
    use crate::ALPRDFloat;
    use crate::alp_rd;

    /// Round-trips a nullable array through an encoder fitted to an unrelated value
    /// (`seed^-2`). NOTE(review): presumably this forces the data onto the patches path, as the
    /// test name suggests; confirm.
    #[rstest]
    #[case(vec![0.1f32.next_up(); 1024], 1.123_848_f32)]
    #[case(vec![0.1f64.next_up(); 1024], 1.123_848_591_110_992_f64)]
    fn test_array_encode_with_nulls_and_patches<T: ALPRDFloat>(
        #[case] reals: Vec<T>,
        #[case] seed: T,
    ) {
        assert_eq!(reals.len(), 1024, "test expects 1024-length fixture");

        const NULL_POSITIONS: [usize; 3] = [1, 5, 900];
        let reals: Vec<Option<T>> = reals
            .into_iter()
            .enumerate()
            .map(|(idx, value)| (!NULL_POSITIONS.contains(&idx)).then_some(value))
            .collect();

        let input = PrimitiveArray::from_option_iter(reals.iter().cloned());
        let encoder: alp_rd::RDEncoder = alp_rd::RDEncoder::new(&[seed.powi(-2)]);
        let encoded = encoder.encode(&input);

        let decoded = encoded.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_option_iter(reals));
    }

    /// Snapshot test pinning the serialized layout of [`ALPRDMetadata`].
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_alprd_metadata() {
        let metadata = ALPRDMetadata {
            right_bit_width: u32::MAX,
            dict_len: 8,
            dict: Vec::new(),
            left_parts_ptype: PType::U64 as i32,
            patches: Some(PatchesMetadata::new(
                usize::MAX,
                usize::MAX,
                PType::U64,
                None,
                None,
                None,
            )),
        };
        check_metadata("alprd.metadata", ProstMetadata(metadata));
    }
}