vortex_zigzag/
compute.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::compute::{
5    FilterKernel, FilterKernelAdapter, MaskKernel, MaskKernelAdapter, TakeKernel,
6    TakeKernelAdapter, filter, mask, take,
7};
8use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
9use vortex_error::VortexResult;
10use vortex_mask::Mask;
11
12use crate::{ZigZagArray, ZigZagVTable};
13
14impl FilterKernel for ZigZagVTable {
15    fn filter(&self, array: &ZigZagArray, mask: &Mask) -> VortexResult<ArrayRef> {
16        let encoded = filter(array.encoded(), mask)?;
17        Ok(ZigZagArray::try_new(encoded)?.into_array())
18    }
19}
20
21register_kernel!(FilterKernelAdapter(ZigZagVTable).lift());
22
23impl TakeKernel for ZigZagVTable {
24    fn take(&self, array: &ZigZagArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
25        let encoded = take(array.encoded(), indices)?;
26        Ok(ZigZagArray::try_new(encoded)?.into_array())
27    }
28}
29
30register_kernel!(TakeKernelAdapter(ZigZagVTable).lift());
31
32impl MaskKernel for ZigZagVTable {
33    fn mask(&self, array: &ZigZagArray, filter_mask: &Mask) -> VortexResult<ArrayRef> {
34        let encoded = mask(array.encoded(), filter_mask)?;
35        Ok(ZigZagArray::try_new(encoded)?.into_array())
36    }
37}
38
39register_kernel!(MaskKernelAdapter(ZigZagVTable).lift());
40
41pub(crate) trait ZigZagEncoded {
42    type Int: zigzag::ZigZag;
43}
44
45impl ZigZagEncoded for u8 {
46    type Int = i8;
47}
48
49impl ZigZagEncoded for u16 {
50    type Int = i16;
51}
52
53impl ZigZagEncoded for u32 {
54    type Int = i32;
55}
56
57impl ZigZagEncoded for u64 {
58    type Int = i64;
59}
60
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::{BooleanBuffer, PrimitiveArray};
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::compute::{filter, take};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use crate::{ZigZagArray, ZigZagEncoding, zigzag_encode};

    #[test]
    fn nullable_scalar_at() {
        let zigzag = ZigZagEncoding
            .encode(
                &PrimitiveArray::new(buffer![-189, -160, 1], Validity::AllValid)
                    .to_canonical()
                    .unwrap(),
                None,
            )
            .unwrap()
            .unwrap();
        assert_eq!(
            zigzag.scalar_at(1).unwrap(),
            Scalar::primitive(-160, Nullability::Nullable)
        );
    }

    #[test]
    fn take_zigzag() {
        let zigzag = ZigZagEncoding
            .encode(
                &buffer![-189, -160, 1].into_array().to_canonical().unwrap(),
                None,
            )
            .unwrap()
            .unwrap();

        let indices = buffer![0, 2].into_array();
        let actual = take(&zigzag, &indices).unwrap().to_primitive().unwrap();
        // Compare against literal values rather than a second ZigZag encode/decode round
        // trip, so a defect shared by encode and decode cannot make both sides agree.
        assert_eq!(actual.as_slice::<i32>(), &[-189, 1]);
    }

    #[test]
    fn filter_zigzag() {
        let zigzag = ZigZagEncoding
            .encode(
                &buffer![-189, -160, 1].into_array().to_canonical().unwrap(),
                None,
            )
            .unwrap()
            .unwrap();
        let filter_mask = BooleanBuffer::from(vec![true, false, true]).into();
        let actual = filter(&zigzag, &filter_mask)
            .unwrap()
            .to_primitive()
            .unwrap();
        // Literal expectation: independent of the ZigZag codec under test.
        assert_eq!(actual.as_slice::<i32>(), &[-189, 1]);
    }

    #[test]
    fn test_filter_conformance() {
        use vortex_array::compute::conformance::filter::test_filter_conformance;

        // Test with i32 values
        let zigzag = ZigZagEncoding
            .encode(
                &buffer![-189i32, -160, 1, 42, -73]
                    .into_array()
                    .to_canonical()
                    .unwrap(),
                None,
            )
            .unwrap()
            .unwrap();
        test_filter_conformance(zigzag.as_ref());

        // Test with i64 values
        let zigzag = ZigZagEncoding
            .encode(
                &buffer![1000i64, -2000, 3000, -4000, 5000]
                    .into_array()
                    .to_canonical()
                    .unwrap(),
                None,
            )
            .unwrap()
            .unwrap();
        test_filter_conformance(zigzag.as_ref());

        // Test with nullable values
        let array =
            PrimitiveArray::from_option_iter([Some(-10i16), None, Some(20), Some(-30), None]);
        let zigzag = ZigZagEncoding
            .encode(&array.to_canonical().unwrap(), None)
            .unwrap()
            .unwrap();
        test_filter_conformance(zigzag.as_ref());
    }

    #[test]
    fn test_mask_conformance() {
        use vortex_array::compute::conformance::mask::test_mask_conformance;

        // Test with i32 values
        let zigzag = ZigZagEncoding
            .encode(
                &buffer![-100i32, 200, -300, 400, -500]
                    .into_array()
                    .to_canonical()
                    .unwrap(),
                None,
            )
            .unwrap()
            .unwrap();
        test_mask_conformance(zigzag.as_ref());

        // Test with i8 values
        let zigzag = ZigZagEncoding
            .encode(
                &buffer![-127i8, 0, 127, -1, 1]
                    .into_array()
                    .to_canonical()
                    .unwrap(),
                None,
            )
            .unwrap()
            .unwrap();
        test_mask_conformance(zigzag.as_ref());
    }

    #[rstest]
    #[case(buffer![-189i32, -160, 1, 42, -73].into_array())]
    #[case(buffer![1000i64, -2000, 3000, -4000, 5000].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(-10i16), None, Some(20), Some(-30), None]).into_array())]
    #[case(buffer![42i32].into_array())]
    fn test_take_zigzag_conformance(#[case] array: ArrayRef) {
        use vortex_array::compute::conformance::take::test_take_conformance;

        let zigzag = ZigZagEncoding
            .encode(&array.to_canonical().unwrap(), None)
            .unwrap()
            .unwrap();
        test_take_conformance(zigzag.as_ref());
    }

    #[rstest]
    // Basic ZigZag arrays
    #[case::zigzag_i8(zigzag_encode(PrimitiveArray::from_iter([-128i8, -1, 0, 1, 127])).unwrap())]
    #[case::zigzag_i16(zigzag_encode(PrimitiveArray::from_iter([-1000i16, -100, 0, 100, 1000])).unwrap())]
    #[case::zigzag_i32(zigzag_encode(PrimitiveArray::from_iter([-100000i32, -1000, 0, 1000, 100000])).unwrap())]
    #[case::zigzag_i64(zigzag_encode(PrimitiveArray::from_iter([-1000000i64, -10000, 0, 10000, 1000000])).unwrap())]
    // Nullable arrays
    #[case::zigzag_nullable_i32(zigzag_encode(PrimitiveArray::from_option_iter([Some(-100i32), None, Some(0), Some(100), None])).unwrap())]
    #[case::zigzag_nullable_i64(zigzag_encode(PrimitiveArray::from_option_iter([Some(-1000i64), None, Some(0), Some(1000), None])).unwrap())]
    // Edge cases
    #[case::zigzag_single(zigzag_encode(PrimitiveArray::from_iter([-42i32])).unwrap())]
    #[case::zigzag_alternating(zigzag_encode(PrimitiveArray::from_iter([-1i32, 1, -2, 2, -3, 3])).unwrap())]
    // Large arrays
    #[case::zigzag_large_i32(zigzag_encode(PrimitiveArray::from_iter(-500..500)).unwrap())]
    #[case::zigzag_large_i64(zigzag_encode(PrimitiveArray::from_iter((-1000..1000).map(|i| i as i64 * 100))).unwrap())]
    fn test_zigzag_consistency(#[case] array: ZigZagArray) {
        test_array_consistency(array.as_ref());
    }

    #[rstest]
    #[case::zigzag_i8_basic(zigzag_encode(PrimitiveArray::from_iter([-10i8, -5, 0, 5, 10])).unwrap())]
    #[case::zigzag_i16_basic(zigzag_encode(PrimitiveArray::from_iter([-100i16, -50, 0, 50, 100])).unwrap())]
    #[case::zigzag_i32_basic(zigzag_encode(PrimitiveArray::from_iter([-1000i32, -500, 0, 500, 1000])).unwrap())]
    #[case::zigzag_i64_basic(zigzag_encode(PrimitiveArray::from_iter([-10000i64, -5000, 0, 5000, 10000])).unwrap())]
    #[case::zigzag_i32_large(zigzag_encode(PrimitiveArray::from_iter((-50..50).map(|i| i * 10))).unwrap())]
    fn test_zigzag_binary_numeric(#[case] array: ZigZagArray) {
        test_binary_numeric_array(array.into_array());
    }
}