vortex_alp/alp/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use vortex_array::Array;
8use vortex_array::ArrayBufferVisitor;
9use vortex_array::ArrayChildVisitor;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayRef;
13use vortex_array::Canonical;
14use vortex_array::DeserializeMetadata;
15use vortex_array::ExecutionCtx;
16use vortex_array::Precision;
17use vortex_array::ProstMetadata;
18use vortex_array::SerializeMetadata;
19use vortex_array::VectorExecutor;
20use vortex_array::buffer::BufferHandle;
21use vortex_array::patches::Patches;
22use vortex_array::patches::PatchesMetadata;
23use vortex_array::serde::ArrayChildren;
24use vortex_array::stats::ArrayStats;
25use vortex_array::stats::StatsSetRef;
26use vortex_array::vtable;
27use vortex_array::vtable::ArrayId;
28use vortex_array::vtable::ArrayVTable;
29use vortex_array::vtable::ArrayVTableExt;
30use vortex_array::vtable::BaseArrayVTable;
31use vortex_array::vtable::CanonicalVTable;
32use vortex_array::vtable::EncodeVTable;
33use vortex_array::vtable::NotSupported;
34use vortex_array::vtable::VTable;
35use vortex_array::vtable::ValidityChild;
36use vortex_array::vtable::ValidityVTableFromChild;
37use vortex_array::vtable::VisitorVTable;
38use vortex_dtype::DType;
39use vortex_dtype::PType;
40use vortex_error::VortexError;
41use vortex_error::VortexExpect;
42use vortex_error::VortexResult;
43use vortex_error::vortex_bail;
44use vortex_error::vortex_ensure;
45use vortex_error::vortex_err;
46use vortex_vector::Vector;
47
48use crate::ALPFloat;
49use crate::alp::Exponents;
50use crate::alp::alp_encode;
51use crate::alp::decompress::decompress_into_array;
52use crate::alp::decompress::decompress_into_vector;
53use crate::match_each_alp_float_ptype;
54
// Invoke the `vtable!` macro (imported from `vortex_array`) for the `ALP` encoding.
// NOTE(review): the macro body is not visible in this file — presumably it generates the
// glue that ties `ALPVTable` and `ALPArray` into the generic array machinery; confirm
// against the macro definition.
55vtable!(ALP);
56
57impl VTable for ALPVTable {
58    type Array = ALPArray;
59
60    type Metadata = ProstMetadata<ALPMetadata>;
61
62    type ArrayVTable = Self;
63    type CanonicalVTable = Self;
64    type OperationsVTable = Self;
65    type ValidityVTable = ValidityVTableFromChild;
66    type VisitorVTable = Self;
67    type ComputeVTable = NotSupported;
68    type EncodeVTable = Self;
69
70    fn id(&self) -> ArrayId {
71        ArrayId::new_ref("vortex.alp")
72    }
73
74    fn encoding(_array: &Self::Array) -> ArrayVTable {
75        ALPVTable.as_vtable()
76    }
77
78    fn metadata(array: &ALPArray) -> VortexResult<Self::Metadata> {
79        let exponents = array.exponents();
80        Ok(ProstMetadata(ALPMetadata {
81            exp_e: exponents.e as u32,
82            exp_f: exponents.f as u32,
83            patches: array
84                .patches()
85                .map(|p| p.to_metadata(array.len(), array.dtype()))
86                .transpose()?,
87        }))
88    }
89
90    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
91        Ok(Some(metadata.serialize()))
92    }
93
94    fn deserialize(buffer: &[u8]) -> VortexResult<Self::Metadata> {
95        Ok(ProstMetadata(
96            <ProstMetadata<ALPMetadata> as DeserializeMetadata>::deserialize(buffer)?,
97        ))
98    }
99
100    fn build(
101        &self,
102        dtype: &DType,
103        len: usize,
104        metadata: &Self::Metadata,
105        _buffers: &[BufferHandle],
106        children: &dyn ArrayChildren,
107    ) -> VortexResult<ALPArray> {
108        let encoded_ptype = match &dtype {
109            DType::Primitive(PType::F32, n) => DType::Primitive(PType::I32, *n),
110            DType::Primitive(PType::F64, n) => DType::Primitive(PType::I64, *n),
111            d => vortex_bail!(MismatchedTypes: "f32 or f64", d),
112        };
113        let encoded = children.get(0, &encoded_ptype, len)?;
114
115        let patches = metadata
116            .patches
117            .map(|p| {
118                let indices = children.get(1, &p.indices_dtype(), p.len())?;
119                let values = children.get(2, dtype, p.len())?;
120                let chunk_offsets = p
121                    .chunk_offsets_dtype()
122                    .map(|dtype| children.get(3, &dtype, usize::try_from(p.chunk_offsets_len())?))
123                    .transpose()?;
124
125                Ok::<_, VortexError>(Patches::new(
126                    len,
127                    p.offset(),
128                    indices,
129                    values,
130                    chunk_offsets,
131                ))
132            })
133            .transpose()?;
134
135        ALPArray::try_new(
136            encoded,
137            Exponents {
138                e: u8::try_from(metadata.exp_e)?,
139                f: u8::try_from(metadata.exp_f)?,
140            },
141            patches,
142        )
143    }
144
145    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
146        // Children: encoded, patches (if present): indices, values, chunk_offsets (optional)
147        let patches_info = array
148            .patches
149            .as_ref()
150            .map(|p| (p.array_len(), p.offset(), p.chunk_offsets().is_some()));
151
152        let expected_children = match &patches_info {
153            Some((_, _, has_chunk_offsets)) => 1 + 2 + if *has_chunk_offsets { 1 } else { 0 },
154            None => 1,
155        };
156
157        vortex_ensure!(
158            children.len() == expected_children,
159            "ALPArray expects {} children, got {}",
160            expected_children,
161            children.len()
162        );
163
164        let mut children_iter = children.into_iter();
165        array.encoded = children_iter
166            .next()
167            .ok_or_else(|| vortex_err!("Expected encoded child"))?;
168
169        if let Some((array_len, offset, _has_chunk_offsets)) = patches_info {
170            let indices = children_iter
171                .next()
172                .ok_or_else(|| vortex_err!("Expected patch indices child"))?;
173            let values = children_iter
174                .next()
175                .ok_or_else(|| vortex_err!("Expected patch values child"))?;
176            let chunk_offsets = children_iter.next();
177
178            array.patches = Some(Patches::new(
179                array_len,
180                offset,
181                indices,
182                values,
183                chunk_offsets,
184            ));
185        }
186
187        Ok(())
188    }
189
190    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
191        let encoded = array.encoded().execute(ctx)?;
192        let patches = if let Some(patches) = array.patches() {
193            Some((
194                patches.indices().execute(ctx)?,
195                patches.values().execute(ctx)?,
196                patches
197                    .chunk_offsets()
198                    .as_ref()
199                    .map(|co| co.execute(ctx))
200                    .transpose()?,
201            ))
202        } else {
203            None
204        };
205
206        let patches_offset = array.patches().map(|p| p.offset()).unwrap_or(0);
207        let exponents = array.exponents();
208
209        match_each_alp_float_ptype!(array.dtype().as_ptype(), |T| {
210            decompress_into_vector::<T>(encoded, exponents, patches, patches_offset)
211        })
212    }
213}
214
215#[derive(Clone, Debug)]
216pub struct ALPArray {
217    encoded: ArrayRef,
218    patches: Option<Patches>,
219    dtype: DType,
220    exponents: Exponents,
221    stats_set: ArrayStats,
222}
223
/// Zero-sized marker type on which the vtable traits for [`ALPArray`] are implemented.
#[derive(Debug)]
pub struct ALPVTable;
226
227#[derive(Clone, prost::Message)]
228pub struct ALPMetadata {
229    #[prost(uint32, tag = "1")]
230    pub(crate) exp_e: u32,
231    #[prost(uint32, tag = "2")]
232    pub(crate) exp_f: u32,
233    #[prost(message, optional, tag = "3")]
234    pub(crate) patches: Option<PatchesMetadata>,
235}
236
237impl ALPArray {
238    fn validate(
239        encoded: &dyn Array,
240        exponents: Exponents,
241        patches: Option<&Patches>,
242    ) -> VortexResult<()> {
243        vortex_ensure!(
244            matches!(
245                encoded.dtype(),
246                DType::Primitive(PType::I32 | PType::I64, _)
247            ),
248            "ALP encoded ints have invalid DType {}",
249            encoded.dtype(),
250        );
251
252        // Validate exponents are in-bounds for the float, and that patches have the proper
253        // length and type.
254        let Exponents { e, f } = exponents;
255        match encoded.dtype().as_ptype() {
256            PType::I32 => {
257                vortex_ensure!(exponents.e <= f32::MAX_EXPONENT, "e out of bounds: {e}");
258                vortex_ensure!(exponents.f <= f32::MAX_EXPONENT, "f out of bounds: {f}");
259                if let Some(patches) = patches {
260                    Self::validate_patches::<f32>(patches, encoded)?;
261                }
262            }
263            PType::I64 => {
264                vortex_ensure!(e <= f64::MAX_EXPONENT, "e out of bounds: {e}");
265                vortex_ensure!(f <= f64::MAX_EXPONENT, "f out of bounds: {f}");
266
267                if let Some(patches) = patches {
268                    Self::validate_patches::<f64>(patches, encoded)?;
269                }
270            }
271            _ => unreachable!(),
272        }
273
274        // Validate patches
275        if let Some(patches) = patches {
276            vortex_ensure!(
277                patches.array_len() == encoded.len(),
278                "patches array_len != encoded len: {} != {}",
279                patches.array_len(),
280                encoded.len()
281            );
282
283            // Verify that the patches DType are of the proper DType.
284        }
285
286        Ok(())
287    }
288
289    /// Validate that any patches provided are valid for the ALPArray.
290    fn validate_patches<T: ALPFloat>(patches: &Patches, encoded: &dyn Array) -> VortexResult<()> {
291        vortex_ensure!(
292            patches.array_len() == encoded.len(),
293            "patches array_len != encoded len: {} != {}",
294            patches.array_len(),
295            encoded.len()
296        );
297
298        let expected_type = DType::Primitive(T::PTYPE, encoded.dtype().nullability());
299        vortex_ensure!(
300            patches.dtype() == &expected_type,
301            "Expected patches type {expected_type}, actual {}",
302            patches.dtype(),
303        );
304
305        Ok(())
306    }
307}
308
309impl ALPArray {
310    /// Build a new `ALPArray` from components, panicking on validation failure.
311    ///
312    /// See [`ALPArray::try_new`] for reference on preconditions that must pass before
313    /// calling this method.
314    pub fn new(encoded: ArrayRef, exponents: Exponents, patches: Option<Patches>) -> Self {
315        Self::try_new(encoded, exponents, patches).vortex_expect("ALPArray new")
316    }
317
318    /// Build a new `ALPArray` from components:
319    ///
320    /// * `encoded` contains the ALP-encoded ints. Any null values are replaced with placeholders
321    /// * `exponents` are the ALP exponents, valid range depends on the data type
322    /// * `patches` are any patch values that don't cleanly encode using the ALP conversion function
323    ///
324    /// This method validates the inputs and will return an error if any validation fails.
325    ///
326    /// # Validation
327    ///
328    /// * The `encoded` array must be either `i32` or `i64`
329    ///     * If `i32`, any `patches` must have DType `f32` with same nullability
330    ///     * If `i64`, then `patches`must have DType `f64` with same nullability
331    /// * `exponents` must be in the valid range depending on if the ALPArray is of type `f32` or
332    ///   `f64`.
333    /// * `patches` must have an `array_len` equal to the length of `encoded`
334    ///
335    /// Any failure of these preconditions will result in an error being returned.
336    ///
337    /// # Examples
338    ///
339    /// ```
340    /// # use vortex_alp::{ALPArray, Exponents};
341    /// # use vortex_array::IntoArray;
342    /// # use vortex_buffer::buffer;
343    ///
344    /// // Returns error because buffer has wrong PType.
345    /// let result = ALPArray::try_new(
346    ///     buffer![1i8].into_array(),
347    ///     Exponents { e: 1, f: 1 },
348    ///     None
349    /// );
350    /// assert!(result.is_err());
351    ///
352    /// // Returns error because Exponents are out of bounds for f32
353    /// let result = ALPArray::try_new(
354    ///     buffer![1i32, 2i32].into_array(),
355    ///     Exponents { e: 100, f: 100 },
356    ///     None
357    /// );
358    /// assert!(result.is_err());
359    ///
360    /// // Success!
361    /// let value = ALPArray::try_new(
362    ///     buffer![0i32].into_array(),
363    ///     Exponents { e: 1, f: 1 },
364    ///     None
365    /// ).unwrap();
366    ///
367    /// assert_eq!(value.scalar_at(0), 0f32.into());
368    /// ```
369    pub fn try_new(
370        encoded: ArrayRef,
371        exponents: Exponents,
372        patches: Option<Patches>,
373    ) -> VortexResult<Self> {
374        Self::validate(&encoded, exponents, patches.as_ref())?;
375
376        let dtype = match encoded.dtype() {
377            DType::Primitive(PType::I32, nullability) => DType::Primitive(PType::F32, *nullability),
378            DType::Primitive(PType::I64, nullability) => DType::Primitive(PType::F64, *nullability),
379            _ => unreachable!(),
380        };
381
382        Ok(Self {
383            dtype,
384            encoded,
385            exponents,
386            patches,
387            stats_set: Default::default(),
388        })
389    }
390
391    /// Build a new `ALPArray` from components without validation.
392    ///
393    /// See [`ALPArray::try_new`] for information about the preconditions that should be checked
394    /// **before** calling this method.
395    pub(crate) unsafe fn new_unchecked(
396        encoded: ArrayRef,
397        exponents: Exponents,
398        patches: Option<Patches>,
399        dtype: DType,
400    ) -> Self {
401        Self {
402            dtype,
403            encoded,
404            exponents,
405            patches,
406            stats_set: Default::default(),
407        }
408    }
409
410    pub fn ptype(&self) -> PType {
411        self.dtype.as_ptype()
412    }
413
414    pub fn encoded(&self) -> &ArrayRef {
415        &self.encoded
416    }
417
418    #[inline]
419    pub fn exponents(&self) -> Exponents {
420        self.exponents
421    }
422
423    pub fn patches(&self) -> Option<&Patches> {
424        self.patches.as_ref()
425    }
426
427    /// Consumes the array and returns its parts.
428    #[inline]
429    pub fn into_parts(self) -> (ArrayRef, Exponents, Option<Patches>, DType) {
430        (self.encoded, self.exponents, self.patches, self.dtype)
431    }
432}
433
434impl ValidityChild<ALPVTable> for ALPVTable {
435    fn validity_child(array: &ALPArray) -> &ArrayRef {
436        array.encoded()
437    }
438}
439
440impl BaseArrayVTable<ALPVTable> for ALPVTable {
441    fn len(array: &ALPArray) -> usize {
442        array.encoded.len()
443    }
444
445    fn dtype(array: &ALPArray) -> &DType {
446        &array.dtype
447    }
448
449    fn stats(array: &ALPArray) -> StatsSetRef<'_> {
450        array.stats_set.to_ref(array.as_ref())
451    }
452
453    fn array_hash<H: std::hash::Hasher>(array: &ALPArray, state: &mut H, precision: Precision) {
454        array.dtype.hash(state);
455        array.encoded.array_hash(state, precision);
456        array.exponents.hash(state);
457        array.patches.array_hash(state, precision);
458    }
459
460    fn array_eq(array: &ALPArray, other: &ALPArray, precision: Precision) -> bool {
461        array.dtype == other.dtype
462            && array.encoded.array_eq(&other.encoded, precision)
463            && array.exponents == other.exponents
464            && array.patches.array_eq(&other.patches, precision)
465    }
466}
467
468impl CanonicalVTable<ALPVTable> for ALPVTable {
469    fn canonicalize(array: &ALPArray) -> Canonical {
470        Canonical::Primitive(decompress_into_array(array.clone()))
471    }
472}
473
474impl EncodeVTable<ALPVTable> for ALPVTable {
475    fn encode(
476        _vtable: &ALPVTable,
477        canonical: &Canonical,
478        like: Option<&ALPArray>,
479    ) -> VortexResult<Option<ALPArray>> {
480        let parray = canonical.clone().into_primitive();
481        let exponents = like.map(|a| a.exponents());
482        let alp = alp_encode(&parray, exponents)?;
483
484        Ok(Some(alp))
485    }
486}
487
488impl VisitorVTable<ALPVTable> for ALPVTable {
489    fn visit_buffers(_array: &ALPArray, _visitor: &mut dyn ArrayBufferVisitor) {}
490
491    fn visit_children(array: &ALPArray, visitor: &mut dyn ArrayChildVisitor) {
492        visitor.visit_child("encoded", array.encoded());
493        if let Some(patches) = array.patches() {
494            visitor.visit_patches(patches);
495        }
496    }
497}
498
#[cfg(test)]
mod tests {
    use std::f64::consts::PI;
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_array::VectorExecutor;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::session::ArraySession;
    use vortex_array::vtable::ValidityHelper;
    use vortex_dtype::PTypeDowncast;
    use vortex_session::VortexSession;
    use vortex_vector::VectorOps;

    use super::*;

    /// Session shared by all tests, created lazily on first use.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Vector execution of patch-free `f32` data must agree with array decompression.
    #[rstest]
    #[case(0)]
    #[case(1)]
    #[case(100)]
    #[case(1023)]
    #[case(1024)]
    #[case(1025)]
    #[case(2047)]
    #[case(2048)]
    #[case(2049)]
    fn test_execute_f32(#[case] size: usize) {
        let input = PrimitiveArray::from_iter((0..size).map(|i| i as f32));
        let alp = alp_encode(&input, None).unwrap();

        let executed = alp.to_array().execute_vector(&SESSION).unwrap();
        // Reference result from the array-based decompression path.
        let reference = decompress_into_array(alp);

        assert_eq!(executed.len(), size);

        let executed_f32 = executed.into_primitive().into_f32();
        assert_eq!(executed_f32.as_ref(), reference.as_slice::<f32>());
    }

    /// Vector execution of patch-free `f64` data must agree with array decompression.
    #[rstest]
    #[case(0)]
    #[case(1)]
    #[case(100)]
    #[case(1023)]
    #[case(1024)]
    #[case(1025)]
    #[case(2047)]
    #[case(2048)]
    #[case(2049)]
    fn test_execute_f64(#[case] size: usize) {
        let input = PrimitiveArray::from_iter((0..size).map(|i| i as f64));
        let alp = alp_encode(&input, None).unwrap();

        let executed = alp.to_array().execute_vector(&SESSION).unwrap();
        // Reference result from the array-based decompression path.
        let reference = decompress_into_array(alp);

        assert_eq!(executed.len(), size);

        let executed_f64 = executed.into_primitive().into_f64();
        assert_eq!(executed_f64.as_ref(), reference.as_slice::<f64>());
    }

    /// Every fourth value is `PI`, so the encoded array is expected to carry patches.
    #[rstest]
    #[case(100)]
    #[case(1023)]
    #[case(1024)]
    #[case(1025)]
    #[case(2047)]
    #[case(2048)]
    #[case(2049)]
    fn test_execute_with_patches(#[case] size: usize) {
        let raw: Vec<f64> = (0..size)
            .map(|i| if i % 4 == 3 { PI } else { 1.0 })
            .collect();

        let input = PrimitiveArray::from_iter(raw);
        let alp = alp_encode(&input, None).unwrap();
        // NOTE(review): `array_len` is the length of the patched array, not the number of
        // patches; it is the `unwrap` that actually proves patches are present.
        assert!(alp.patches().unwrap().array_len() > 0);

        let executed = alp.to_array().execute_vector(&SESSION).unwrap();
        // Reference result from the array-based decompression path.
        let reference = decompress_into_array(alp);

        assert_eq!(executed.len(), size);

        let executed_f64 = executed.into_primitive().into_f64();
        assert_eq!(executed_f64.as_ref(), reference.as_slice::<f64>());
    }

    /// Every odd position is null; both values and validity must agree with the reference.
    #[rstest]
    #[case(0)]
    #[case(1)]
    #[case(100)]
    #[case(1023)]
    #[case(1024)]
    #[case(1025)]
    #[case(2047)]
    #[case(2048)]
    #[case(2049)]
    fn test_execute_with_validity(#[case] size: usize) {
        let raw: Vec<Option<f32>> = (0..size)
            .map(|i| if i % 2 == 1 { None } else { Some(1.0) })
            .collect();

        let input = PrimitiveArray::from_option_iter(raw);
        let alp = alp_encode(&input, None).unwrap();

        let executed = alp.to_array().execute_vector(&SESSION).unwrap();
        // Reference result from the array-based decompression path.
        let reference = decompress_into_array(alp);

        assert_eq!(executed.len(), size);

        let executed_f32 = executed.into_primitive().into_f32();
        assert_eq!(executed_f32.as_ref(), reference.as_slice::<f32>());

        // The validity of each position must match between the two paths.
        for position in 0..size {
            assert_eq!(
                executed_f32.validity().value(position),
                reference.validity().is_valid(position)
            );
        }
    }

    /// Mixes clean values, nulls and `PI` (which is expected to need a patch).
    #[rstest]
    #[case(100)]
    #[case(1023)]
    #[case(1024)]
    #[case(1025)]
    #[case(2047)]
    #[case(2048)]
    #[case(2049)]
    fn test_execute_with_patches_and_validity(#[case] size: usize) {
        let raw: Vec<Option<f64>> = (0..size)
            .map(|idx| match idx % 3 {
                1 => None,
                0 => Some(1.0),
                _ => Some(PI),
            })
            .collect();

        let input = PrimitiveArray::from_option_iter(raw);
        let alp = alp_encode(&input, None).unwrap();
        // NOTE(review): `array_len` is the length of the patched array, not the number of
        // patches; it is the `unwrap` that actually proves patches are present.
        assert!(alp.patches().unwrap().array_len() > 0);

        let executed = alp.to_array().execute_vector(&SESSION).unwrap();
        // Reference result from the array-based decompression path.
        let reference = decompress_into_array(alp);

        assert_eq!(executed.len(), size);

        let executed_f64 = executed.into_primitive().into_f64();
        assert_eq!(executed_f64.as_ref(), reference.as_slice::<f64>());

        // The validity of each position must match between the two paths.
        for position in 0..size {
            assert_eq!(
                executed_f64.validity().value(position),
                reference.validity().is_valid(position)
            );
        }
    }

    /// Executes a slice that trims `slice_start` elements from both ends and compares it
    /// element by element against the matching window of the input values.
    #[rstest]
    #[case(500, 100)]
    #[case(1000, 200)]
    #[case(2048, 512)]
    fn test_execute_sliced_vector(#[case] size: usize, #[case] slice_start: usize) {
        let values: Vec<Option<f64>> = (0..size)
            .map(|i| match (i % 5 == 0, i % 4 == 3) {
                (true, _) => None,
                (false, true) => Some(PI),
                (false, false) => Some(1.0),
            })
            .collect();

        let input = PrimitiveArray::from_option_iter(values.clone());
        let alp = alp_encode(&input, None).unwrap();

        let slice_end = size - slice_start;
        let sliced = alp.slice(slice_start..slice_end);

        let executed = sliced.execute_vector(&SESSION).unwrap();
        let result_primitive = executed.into_primitive().into_f64();

        let expected_window = &values[slice_start..slice_end];
        for (idx, expected_value) in expected_window.iter().copied().enumerate() {
            let result_valid = result_primitive.validity().value(idx);
            assert_eq!(
                result_valid,
                expected_value.is_some(),
                "Validity mismatch at idx={idx}",
            );

            // Values are only compared where the input was non-null.
            if let Some(expected_val) = expected_value {
                let result_val = result_primitive.as_ref()[idx];
                assert_eq!(result_val, expected_val, "Value mismatch at idx={idx}",);
            }
        }
    }
}