vortex_zigzag/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod cast;
5
6use vortex_array::compute::{
7    FilterKernel, FilterKernelAdapter, MaskKernel, MaskKernelAdapter, TakeKernel,
8    TakeKernelAdapter, filter, mask, take,
9};
10use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
11use vortex_error::VortexResult;
12use vortex_mask::Mask;
13
14use crate::{ZigZagArray, ZigZagVTable};
15
16impl FilterKernel for ZigZagVTable {
17    fn filter(&self, array: &ZigZagArray, mask: &Mask) -> VortexResult<ArrayRef> {
18        let encoded = filter(array.encoded(), mask)?;
19        Ok(ZigZagArray::try_new(encoded)?.into_array())
20    }
21}
22
23register_kernel!(FilterKernelAdapter(ZigZagVTable).lift());
24
25impl TakeKernel for ZigZagVTable {
26    fn take(&self, array: &ZigZagArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
27        let encoded = take(array.encoded(), indices)?;
28        Ok(ZigZagArray::try_new(encoded)?.into_array())
29    }
30}
31
32register_kernel!(TakeKernelAdapter(ZigZagVTable).lift());
33
34impl MaskKernel for ZigZagVTable {
35    fn mask(&self, array: &ZigZagArray, filter_mask: &Mask) -> VortexResult<ArrayRef> {
36        let encoded = mask(array.encoded(), filter_mask)?;
37        Ok(ZigZagArray::try_new(encoded)?.into_array())
38    }
39}
40
41register_kernel!(MaskKernelAdapter(ZigZagVTable).lift());
42
43pub(crate) trait ZigZagEncoded {
44    type Int: zigzag::ZigZag;
45}
46
47impl ZigZagEncoded for u8 {
48    type Int = i8;
49}
50
51impl ZigZagEncoded for u16 {
52    type Int = i16;
53}
54
55impl ZigZagEncoded for u32 {
56    type Int = i32;
57}
58
59impl ZigZagEncoded for u64 {
60    type Int = i64;
61}
62
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::{BooleanBuffer, PrimitiveArray};
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::compute::{filter, take};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use crate::{ZigZagArray, ZigZagEncoding, zigzag_encode};

    /// `scalar_at` on a ZigZag array built from an all-valid (nullable) primitive
    /// array returns the original signed value with `Nullable` nullability.
    #[test]
    fn nullable_scalar_at() {
        let zigzag = ZigZagEncoding
            .encode(
                &PrimitiveArray::new(buffer![-189, -160, 1], Validity::AllValid)
                    .to_canonical()
                    .unwrap(),
                None,
            )
            .unwrap()
            .unwrap();
        assert_eq!(
            zigzag.scalar_at(1).unwrap(),
            Scalar::primitive(-160, Nullability::Nullable)
        );
    }

    /// `take` on a ZigZag array yields the original signed values at the
    /// requested indices.
    #[test]
    fn take_zigzag() {
        let zigzag = ZigZagEncoding
            .encode(
                &buffer![-189, -160, 1].into_array().to_canonical().unwrap(),
                None,
            )
            .unwrap()
            .unwrap();

        let indices = buffer![0, 2].into_array();
        let actual = take(&zigzag, &indices).unwrap().to_primitive().unwrap();
        // Compare against literal values rather than a re-encoded/decoded copy, so a
        // symmetric encode/decode bug cannot make both sides agree.
        assert_eq!(actual.as_slice::<i32>(), &[-189, 1]);
    }

    /// `filter` on a ZigZag array keeps exactly the selected original values.
    #[test]
    fn filter_zigzag() {
        let zigzag = ZigZagEncoding
            .encode(
                &buffer![-189, -160, 1].into_array().to_canonical().unwrap(),
                None,
            )
            .unwrap()
            .unwrap();
        let filter_mask = BooleanBuffer::from(vec![true, false, true]).into();
        let actual = filter(&zigzag, &filter_mask)
            .unwrap()
            .to_primitive()
            .unwrap();
        // Literal expectation: see the note in `take_zigzag`.
        assert_eq!(actual.as_slice::<i32>(), &[-189, 1]);
    }

    /// Runs the shared filter conformance suite over ZigZag arrays of i32, i64,
    /// and nullable i16 values.
    #[test]
    fn test_filter_conformance() {
        use vortex_array::compute::conformance::filter::test_filter_conformance;

        // Test with i32 values
        let zigzag = ZigZagEncoding
            .encode(
                &buffer![-189i32, -160, 1, 42, -73]
                    .into_array()
                    .to_canonical()
                    .unwrap(),
                None,
            )
            .unwrap()
            .unwrap();
        test_filter_conformance(zigzag.as_ref());

        // Test with i64 values
        let zigzag = ZigZagEncoding
            .encode(
                &buffer![1000i64, -2000, 3000, -4000, 5000]
                    .into_array()
                    .to_canonical()
                    .unwrap(),
                None,
            )
            .unwrap()
            .unwrap();
        test_filter_conformance(zigzag.as_ref());

        // Test with nullable values
        let array =
            PrimitiveArray::from_option_iter([Some(-10i16), None, Some(20), Some(-30), None]);
        let zigzag = ZigZagEncoding
            .encode(&array.to_canonical().unwrap(), None)
            .unwrap()
            .unwrap();
        test_filter_conformance(zigzag.as_ref());
    }

    /// Runs the shared mask conformance suite over ZigZag arrays of i32 and i8
    /// values.
    #[test]
    fn test_mask_conformance() {
        use vortex_array::compute::conformance::mask::test_mask_conformance;

        // Test with i32 values
        let zigzag = ZigZagEncoding
            .encode(
                &buffer![-100i32, 200, -300, 400, -500]
                    .into_array()
                    .to_canonical()
                    .unwrap(),
                None,
            )
            .unwrap()
            .unwrap();
        test_mask_conformance(zigzag.as_ref());

        // Test with i8 values
        let zigzag = ZigZagEncoding
            .encode(
                &buffer![-127i8, 0, 127, -1, 1]
                    .into_array()
                    .to_canonical()
                    .unwrap(),
                None,
            )
            .unwrap()
            .unwrap();
        test_mask_conformance(zigzag.as_ref());
    }

    /// Runs the shared take conformance suite over ZigZag encodings of each case.
    #[rstest]
    #[case(buffer![-189i32, -160, 1, 42, -73].into_array())]
    #[case(buffer![1000i64, -2000, 3000, -4000, 5000].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(-10i16), None, Some(20), Some(-30), None]).into_array())]
    #[case(buffer![42i32].into_array())]
    fn test_take_zigzag_conformance(#[case] array: ArrayRef) {
        use vortex_array::compute::conformance::take::test_take_conformance;

        let zigzag = ZigZagEncoding
            .encode(&array.to_canonical().unwrap(), None)
            .unwrap()
            .unwrap();
        test_take_conformance(zigzag.as_ref());
    }

    /// Runs the shared array-consistency suite over a range of ZigZag arrays.
    #[rstest]
    // Basic ZigZag arrays
    #[case::zigzag_i8(zigzag_encode(PrimitiveArray::from_iter([-128i8, -1, 0, 1, 127])).unwrap())]
    #[case::zigzag_i16(zigzag_encode(PrimitiveArray::from_iter([-1000i16, -100, 0, 100, 1000])).unwrap())]
    #[case::zigzag_i32(zigzag_encode(PrimitiveArray::from_iter([-100000i32, -1000, 0, 1000, 100000])).unwrap())]
    #[case::zigzag_i64(zigzag_encode(PrimitiveArray::from_iter([-1000000i64, -10000, 0, 10000, 1000000])).unwrap())]
    // Nullable arrays
    #[case::zigzag_nullable_i32(zigzag_encode(PrimitiveArray::from_option_iter([Some(-100i32), None, Some(0), Some(100), None])).unwrap())]
    #[case::zigzag_nullable_i64(zigzag_encode(PrimitiveArray::from_option_iter([Some(-1000i64), None, Some(0), Some(1000), None])).unwrap())]
    // Edge cases
    #[case::zigzag_single(zigzag_encode(PrimitiveArray::from_iter([-42i32])).unwrap())]
    #[case::zigzag_alternating(zigzag_encode(PrimitiveArray::from_iter([-1i32, 1, -2, 2, -3, 3])).unwrap())]
    // Large arrays
    #[case::zigzag_large_i32(zigzag_encode(PrimitiveArray::from_iter(-500..500)).unwrap())]
    #[case::zigzag_large_i64(zigzag_encode(PrimitiveArray::from_iter((-1000..1000).map(|i| i as i64 * 100))).unwrap())]
    fn test_zigzag_consistency(#[case] array: ZigZagArray) {
        test_array_consistency(array.as_ref());
    }

    /// Runs the shared binary-numeric conformance suite over ZigZag arrays.
    #[rstest]
    #[case::zigzag_i8_basic(zigzag_encode(PrimitiveArray::from_iter([-10i8, -5, 0, 5, 10])).unwrap())]
    #[case::zigzag_i16_basic(zigzag_encode(PrimitiveArray::from_iter([-100i16, -50, 0, 50, 100])).unwrap())]
    #[case::zigzag_i32_basic(zigzag_encode(PrimitiveArray::from_iter([-1000i32, -500, 0, 500, 1000])).unwrap())]
    #[case::zigzag_i64_basic(zigzag_encode(PrimitiveArray::from_iter([-10000i64, -5000, 0, 5000, 10000])).unwrap())]
    #[case::zigzag_i32_large(zigzag_encode(PrimitiveArray::from_iter((-50..50).map(|i| i * 10))).unwrap())]
    fn test_zigzag_binary_numeric(#[case] array: ZigZagArray) {
        test_binary_numeric_array(array.into_array());
    }
}