vortex_zigzag/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5use std::ops::Range;
6
7use vortex_array::stats::{ArrayStats, StatsSetRef};
8use vortex_array::vtable::{
9    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityChild,
10    ValidityVTableFromChild,
11};
12use vortex_array::{
13    Array, ArrayEq, ArrayHash, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, Precision,
14    ToCanonical, vtable,
15};
16use vortex_dtype::{DType, PType, match_each_unsigned_integer_ptype};
17use vortex_error::{VortexExpect, VortexResult, vortex_bail};
18use vortex_scalar::Scalar;
19use zigzag::ZigZag as ExternalZigZag;
20
21use crate::compute::ZigZagEncoded;
22use crate::zigzag_decode;
23
24vtable!(ZigZag);
25
26impl VTable for ZigZagVTable {
27    type Array = ZigZagArray;
28    type Encoding = ZigZagEncoding;
29
30    type ArrayVTable = Self;
31    type CanonicalVTable = Self;
32    type OperationsVTable = Self;
33    type ValidityVTable = ValidityVTableFromChild;
34    type VisitorVTable = Self;
35    type ComputeVTable = NotSupported;
36    type EncodeVTable = Self;
37    type SerdeVTable = Self;
38    type OperatorVTable = NotSupported;
39
40    fn id(_encoding: &Self::Encoding) -> EncodingId {
41        EncodingId::new_ref("vortex.zigzag")
42    }
43
44    fn encoding(_array: &Self::Array) -> EncodingRef {
45        EncodingRef::new_ref(ZigZagEncoding.as_ref())
46    }
47}
48
49#[derive(Clone, Debug)]
50pub struct ZigZagArray {
51    dtype: DType,
52    encoded: ArrayRef,
53    stats_set: ArrayStats,
54}
55
56#[derive(Clone, Debug)]
57pub struct ZigZagEncoding;
58
59impl ZigZagArray {
60    pub fn new(encoded: ArrayRef) -> Self {
61        Self::try_new(encoded).vortex_expect("ZigZigArray new")
62    }
63
64    pub fn try_new(encoded: ArrayRef) -> VortexResult<Self> {
65        let encoded_dtype = encoded.dtype().clone();
66        if !encoded_dtype.is_unsigned_int() {
67            vortex_bail!(MismatchedTypes: "unsigned int", encoded_dtype);
68        }
69
70        let dtype = DType::from(PType::try_from(&encoded_dtype)?.to_signed())
71            .with_nullability(encoded_dtype.nullability());
72
73        Ok(Self {
74            dtype,
75            encoded,
76            stats_set: Default::default(),
77        })
78    }
79
80    pub fn ptype(&self) -> PType {
81        self.dtype().as_ptype()
82    }
83
84    pub fn encoded(&self) -> &ArrayRef {
85        &self.encoded
86    }
87}
88
89impl ArrayVTable<ZigZagVTable> for ZigZagVTable {
90    fn len(array: &ZigZagArray) -> usize {
91        array.encoded.len()
92    }
93
94    fn dtype(array: &ZigZagArray) -> &DType {
95        &array.dtype
96    }
97
98    fn stats(array: &ZigZagArray) -> StatsSetRef<'_> {
99        array.stats_set.to_ref(array.as_ref())
100    }
101
102    fn array_hash<H: std::hash::Hasher>(array: &ZigZagArray, state: &mut H, precision: Precision) {
103        array.dtype.hash(state);
104        array.encoded.array_hash(state, precision);
105    }
106
107    fn array_eq(array: &ZigZagArray, other: &ZigZagArray, precision: Precision) -> bool {
108        array.dtype == other.dtype && array.encoded.array_eq(&other.encoded, precision)
109    }
110}
111
112impl CanonicalVTable<ZigZagVTable> for ZigZagVTable {
113    fn canonicalize(array: &ZigZagArray) -> Canonical {
114        Canonical::Primitive(zigzag_decode(array.encoded().to_primitive()))
115    }
116}
117
118impl OperationsVTable<ZigZagVTable> for ZigZagVTable {
119    fn slice(array: &ZigZagArray, range: Range<usize>) -> ArrayRef {
120        ZigZagArray::new(array.encoded().slice(range)).into_array()
121    }
122
123    fn scalar_at(array: &ZigZagArray, index: usize) -> Scalar {
124        let scalar = array.encoded().scalar_at(index);
125        if scalar.is_null() {
126            return scalar.reinterpret_cast(array.ptype());
127        }
128
129        let pscalar = scalar.as_primitive();
130        match_each_unsigned_integer_ptype!(pscalar.ptype(), |P| {
131            Scalar::primitive(
132                <<P as ZigZagEncoded>::Int>::decode(
133                    pscalar
134                        .typed_value::<P>()
135                        .vortex_expect("zigzag corruption"),
136                ),
137                array.dtype().nullability(),
138            )
139        })
140    }
141}
142
143impl ValidityChild<ZigZagVTable> for ZigZagVTable {
144    fn validity_child(array: &ZigZagArray) -> &dyn Array {
145        array.encoded()
146    }
147}
148
#[cfg(test)]
mod test {
    use vortex_array::IntoArray;
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use super::*;

    /// Statistics of the ZigZag-encoded array, and of a slice of it, must agree with the
    /// statistics of the plain primitive input.
    #[test]
    fn test_compute_statistics() {
        let plain = buffer![1i32, -5i32, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_array();
        let canonical_input = plain.to_canonical();
        let encoded = ZigZagEncoding
            .encode(&canonical_input, None)
            .unwrap()
            .unwrap();

        assert_eq!(
            encoded.statistics().compute_max::<i32>(),
            plain.statistics().compute_max::<i32>()
        );
        assert_eq!(
            encoded.statistics().compute_null_count(),
            plain.statistics().compute_null_count()
        );
        assert_eq!(
            encoded.statistics().compute_is_constant(),
            plain.statistics().compute_is_constant()
        );

        let head_ref = encoded.slice(0..2);
        let head = head_ref.as_::<ZigZagVTable>();
        let last_index = head.len() - 1;
        assert_eq!(head.scalar_at(last_index), Scalar::from(-5i32));

        assert_eq!(
            head.statistics().compute_min::<i32>(),
            plain.statistics().compute_min::<i32>()
        );
        assert_eq!(
            head.statistics().compute_null_count(),
            plain.statistics().compute_null_count()
        );
        assert_eq!(
            head.statistics().compute_is_constant(),
            plain.statistics().compute_is_constant()
        );
    }
}