Skip to main content

vortex_zigzag/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use vortex_array::Array;
7use vortex_array::ArrayBufferVisitor;
8use vortex_array::ArrayChildVisitor;
9use vortex_array::ArrayEq;
10use vortex_array::ArrayHash;
11use vortex_array::ArrayRef;
12use vortex_array::EmptyMetadata;
13use vortex_array::ExecutionCtx;
14use vortex_array::IntoArray;
15use vortex_array::Precision;
16use vortex_array::buffer::BufferHandle;
17use vortex_array::dtype::DType;
18use vortex_array::dtype::PType;
19use vortex_array::match_each_unsigned_integer_ptype;
20use vortex_array::scalar::Scalar;
21use vortex_array::serde::ArrayChildren;
22use vortex_array::stats::ArrayStats;
23use vortex_array::stats::StatsSetRef;
24use vortex_array::vtable;
25use vortex_array::vtable::ArrayId;
26use vortex_array::vtable::BaseArrayVTable;
27use vortex_array::vtable::OperationsVTable;
28use vortex_array::vtable::VTable;
29use vortex_array::vtable::ValidityChild;
30use vortex_array::vtable::ValidityVTableFromChild;
31use vortex_array::vtable::VisitorVTable;
32use vortex_error::VortexExpect;
33use vortex_error::VortexResult;
34use vortex_error::vortex_bail;
35use vortex_error::vortex_ensure;
36use vortex_session::VortexSession;
37use zigzag::ZigZag as ExternalZigZag;
38
39use crate::compute::ZigZagEncoded;
40use crate::kernel::PARENT_KERNELS;
41use crate::rules::RULES;
42use crate::zigzag_decode;
43
44vtable!(ZigZag);
45
46impl VTable for ZigZagVTable {
47    type Array = ZigZagArray;
48
49    type Metadata = EmptyMetadata;
50
51    type ArrayVTable = Self;
52    type OperationsVTable = Self;
53    type ValidityVTable = ValidityVTableFromChild;
54    type VisitorVTable = Self;
55
56    fn id(_array: &Self::Array) -> ArrayId {
57        Self::ID
58    }
59
60    fn metadata(_array: &ZigZagArray) -> VortexResult<Self::Metadata> {
61        Ok(EmptyMetadata)
62    }
63
64    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
65        Ok(Some(vec![]))
66    }
67
68    fn deserialize(
69        _bytes: &[u8],
70        _dtype: &DType,
71        _len: usize,
72        _buffers: &[BufferHandle],
73        _session: &VortexSession,
74    ) -> VortexResult<Self::Metadata> {
75        Ok(EmptyMetadata)
76    }
77
78    fn build(
79        dtype: &DType,
80        len: usize,
81        _metadata: &Self::Metadata,
82        _buffers: &[BufferHandle],
83        children: &dyn ArrayChildren,
84    ) -> VortexResult<ZigZagArray> {
85        if children.len() != 1 {
86            vortex_bail!("Expected 1 child, got {}", children.len());
87        }
88
89        let ptype = PType::try_from(dtype)?;
90        let encoded_type = DType::Primitive(ptype.to_unsigned(), dtype.nullability());
91
92        let encoded = children.get(0, &encoded_type, len)?;
93        ZigZagArray::try_new(encoded)
94    }
95
96    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
97        vortex_ensure!(
98            children.len() == 1,
99            "ZigZagArray expects exactly 1 child (encoded), got {}",
100            children.len()
101        );
102        array.encoded = children.into_iter().next().vortex_expect("checked");
103        Ok(())
104    }
105
106    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
107        Ok(zigzag_decode(array.encoded().clone().execute(ctx)?).into_array())
108    }
109
110    fn reduce_parent(
111        array: &Self::Array,
112        parent: &ArrayRef,
113        child_idx: usize,
114    ) -> VortexResult<Option<ArrayRef>> {
115        RULES.evaluate(array, parent, child_idx)
116    }
117
118    fn execute_parent(
119        array: &Self::Array,
120        parent: &ArrayRef,
121        child_idx: usize,
122        ctx: &mut ExecutionCtx,
123    ) -> VortexResult<Option<ArrayRef>> {
124        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
125    }
126}
127
/// An integer array whose values are held in a zigzag-encoded, unsigned child array.
#[derive(Clone, Debug)]
pub struct ZigZagArray {
    /// Logical dtype: the signed counterpart of the child's dtype, with the child's nullability
    /// (derived in [`ZigZagArray::try_new`]).
    dtype: DType,
    /// The unsigned-integer child holding the encoded values.
    encoded: ArrayRef,
    /// Statistics storage for this array; starts out as the default value.
    stats_set: ArrayStats,
}
134
/// Unit type on which the vtable traits for [`ZigZagArray`] are implemented.
#[derive(Debug)]
pub struct ZigZagVTable;
137
impl ZigZagVTable {
    /// Identifier of the ZigZag encoding, returned by `VTable::id`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.zigzag");
}
141
142impl ZigZagArray {
143    pub fn new(encoded: ArrayRef) -> Self {
144        Self::try_new(encoded).vortex_expect("ZigZagArray new")
145    }
146
147    pub fn try_new(encoded: ArrayRef) -> VortexResult<Self> {
148        let encoded_dtype = encoded.dtype().clone();
149        if !encoded_dtype.is_unsigned_int() {
150            vortex_bail!(MismatchedTypes: "unsigned int", encoded_dtype);
151        }
152
153        let dtype = DType::from(PType::try_from(&encoded_dtype)?.to_signed())
154            .with_nullability(encoded_dtype.nullability());
155
156        Ok(Self {
157            dtype,
158            encoded,
159            stats_set: Default::default(),
160        })
161    }
162
163    pub fn ptype(&self) -> PType {
164        self.dtype().as_ptype()
165    }
166
167    pub fn encoded(&self) -> &ArrayRef {
168        &self.encoded
169    }
170}
171
172impl BaseArrayVTable<ZigZagVTable> for ZigZagVTable {
173    fn len(array: &ZigZagArray) -> usize {
174        array.encoded.len()
175    }
176
177    fn dtype(array: &ZigZagArray) -> &DType {
178        &array.dtype
179    }
180
181    fn stats(array: &ZigZagArray) -> StatsSetRef<'_> {
182        array.stats_set.to_ref(array.as_ref())
183    }
184
185    fn array_hash<H: std::hash::Hasher>(array: &ZigZagArray, state: &mut H, precision: Precision) {
186        array.dtype.hash(state);
187        array.encoded.array_hash(state, precision);
188    }
189
190    fn array_eq(array: &ZigZagArray, other: &ZigZagArray, precision: Precision) -> bool {
191        array.dtype == other.dtype && array.encoded.array_eq(&other.encoded, precision)
192    }
193}
194
195impl OperationsVTable<ZigZagVTable> for ZigZagVTable {
196    fn scalar_at(array: &ZigZagArray, index: usize) -> VortexResult<Scalar> {
197        let scalar = array.encoded().scalar_at(index)?;
198        if scalar.is_null() {
199            return scalar.primitive_reinterpret_cast(array.ptype());
200        }
201
202        let pscalar = scalar.as_primitive();
203        Ok(match_each_unsigned_integer_ptype!(pscalar.ptype(), |P| {
204            Scalar::primitive(
205                <<P as ZigZagEncoded>::Int>::decode(
206                    pscalar
207                        .typed_value::<P>()
208                        .vortex_expect("zigzag corruption"),
209                ),
210                array.dtype().nullability(),
211            )
212        }))
213    }
214}
215
216impl ValidityChild<ZigZagVTable> for ZigZagVTable {
217    fn validity_child(array: &ZigZagArray) -> &ArrayRef {
218        array.encoded()
219    }
220}
221
222impl VisitorVTable<ZigZagVTable> for ZigZagVTable {
223    fn visit_buffers(_array: &ZigZagArray, _visitor: &mut dyn ArrayBufferVisitor) {}
224
225    fn nbuffers(_array: &ZigZagArray) -> usize {
226        0
227    }
228
229    fn visit_children(array: &ZigZagArray, visitor: &mut dyn ArrayChildVisitor) {
230        visitor.visit_child("encoded", array.encoded())
231    }
232
233    fn nchildren(_array: &ZigZagArray) -> usize {
234        1
235    }
236}
237
#[cfg(test)]
mod test {
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::scalar::Scalar;
    use vortex_buffer::buffer;

    use super::*;
    use crate::zigzag_encode;

    /// Checks that statistics computed on a zigzag-encoded array (and on a slice of it) agree
    /// with those computed on the plain source array.
    #[test]
    fn test_compute_statistics() -> VortexResult<()> {
        // Source data contains a single negative value (-5), which is also the minimum.
        let array = buffer![1i32, -5i32, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            .into_array()
            .to_primitive();
        let zigzag = zigzag_encode(array.clone())?;

        // Whole-array statistics must match the source array.
        assert_eq!(
            zigzag.statistics().compute_max::<i32>(),
            array.statistics().compute_max::<i32>()
        );
        assert_eq!(
            zigzag.statistics().compute_null_count(),
            array.statistics().compute_null_count()
        );
        assert_eq!(
            zigzag.statistics().compute_is_constant(),
            array.statistics().compute_is_constant()
        );

        // Slice down to the first two elements, [1, -5]; the last of them decodes to -5.
        let sliced = zigzag.slice(0..2).unwrap();
        let sliced = sliced.as_::<ZigZagVTable>();
        assert_eq!(
            sliced.scalar_at(sliced.len() - 1).unwrap(),
            Scalar::from(-5i32)
        );

        // The slice still contains -5, so its minimum equals the full array's minimum.
        assert_eq!(
            sliced.statistics().compute_min::<i32>(),
            array.statistics().compute_min::<i32>()
        );
        assert_eq!(
            sliced.statistics().compute_null_count(),
            array.statistics().compute_null_count()
        );
        assert_eq!(
            sliced.statistics().compute_is_constant(),
            array.statistics().compute_is_constant()
        );
        Ok(())
    }
}