vortex_zigzag/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod cast;
5
6use vortex_array::Array;
7use vortex_array::ArrayRef;
8use vortex_array::IntoArray;
9use vortex_array::compute::FilterKernel;
10use vortex_array::compute::FilterKernelAdapter;
11use vortex_array::compute::MaskKernel;
12use vortex_array::compute::MaskKernelAdapter;
13use vortex_array::compute::TakeKernel;
14use vortex_array::compute::TakeKernelAdapter;
15use vortex_array::compute::filter;
16use vortex_array::compute::mask;
17use vortex_array::compute::take;
18use vortex_array::register_kernel;
19use vortex_error::VortexResult;
20use vortex_mask::Mask;
21
22use crate::ZigZagArray;
23use crate::ZigZagVTable;
24
25impl FilterKernel for ZigZagVTable {
26    fn filter(&self, array: &ZigZagArray, mask: &Mask) -> VortexResult<ArrayRef> {
27        let encoded = filter(array.encoded(), mask)?;
28        Ok(ZigZagArray::try_new(encoded)?.into_array())
29    }
30}
31
32register_kernel!(FilterKernelAdapter(ZigZagVTable).lift());
33
34impl TakeKernel for ZigZagVTable {
35    fn take(&self, array: &ZigZagArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
36        let encoded = take(array.encoded(), indices)?;
37        Ok(ZigZagArray::try_new(encoded)?.into_array())
38    }
39}
40
41register_kernel!(TakeKernelAdapter(ZigZagVTable).lift());
42
43impl MaskKernel for ZigZagVTable {
44    fn mask(&self, array: &ZigZagArray, filter_mask: &Mask) -> VortexResult<ArrayRef> {
45        let encoded = mask(array.encoded(), filter_mask)?;
46        Ok(ZigZagArray::try_new(encoded)?.into_array())
47    }
48}
49
50register_kernel!(MaskKernelAdapter(ZigZagVTable).lift());
51
52pub(crate) trait ZigZagEncoded {
53    type Int: zigzag::ZigZag;
54}
55
56impl ZigZagEncoded for u8 {
57    type Int = i8;
58}
59
60impl ZigZagEncoded for u16 {
61    type Int = i16;
62}
63
64impl ZigZagEncoded for u32 {
65    type Int = i32;
66}
67
68impl ZigZagEncoded for u64 {
69    type Int = i64;
70}
71
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::compute::filter;
    use vortex_array::compute::take;
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ArrayVTableExt;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use crate::ZigZagArray;
    use crate::ZigZagVTable;
    use crate::zigzag_encode;

    /// Canonicalizes `array` and encodes it through the ZigZag vtable's
    /// `encode` entry point.
    ///
    /// Panics if encoding returns an error. It also panics if the call yields
    /// no array; presumably that means the vtable declined to encode the
    /// input (TODO confirm the `None` semantics of `encode`).
    fn encode_via_vtable(array: ArrayRef) -> ArrayRef {
        ZigZagVTable
            .as_vtable()
            .encode(&array.to_canonical(), None)
            .expect("ZigZag encoding returned an error")
            .expect("ZigZag vtable produced no encoded array")
    }

    #[test]
    fn nullable_scalar_at() {
        let zigzag = encode_via_vtable(
            PrimitiveArray::new(buffer![-189i32, -160, 1], Validity::AllValid).into_array(),
        );
        assert_eq!(
            zigzag.scalar_at(1),
            Scalar::primitive(-160i32, Nullability::Nullable)
        );
    }

    #[test]
    fn take_zigzag() {
        let zigzag = encode_via_vtable(buffer![-189i32, -160, 1].into_array());

        let indices = buffer![0, 2].into_array();
        let actual = take(&zigzag, &indices).unwrap().to_primitive();
        // Compare against plain values rather than another encode/decode round
        // trip, so a decoding bug cannot cancel out on both sides.
        let expected = PrimitiveArray::from_iter([-189i32, 1]);
        assert_arrays_eq!(actual, expected);
    }

    #[test]
    fn filter_zigzag() {
        let zigzag = encode_via_vtable(buffer![-189i32, -160, 1].into_array());
        let filter_mask = BitBuffer::from(vec![true, false, true]).into();
        let actual = filter(&zigzag, &filter_mask).unwrap().to_primitive();
        // Ground-truth expectation, independent of the ZigZag decoder.
        let expected = PrimitiveArray::from_iter([-189i32, 1]);
        assert_arrays_eq!(actual, expected);
    }

    #[test]
    fn test_filter_conformance() {
        use vortex_array::compute::conformance::filter::test_filter_conformance;

        // Test with i32 values
        let zigzag = encode_via_vtable(buffer![-189i32, -160, 1, 42, -73].into_array());
        test_filter_conformance(zigzag.as_ref());

        // Test with i64 values
        let zigzag = encode_via_vtable(buffer![1000i64, -2000, 3000, -4000, 5000].into_array());
        test_filter_conformance(zigzag.as_ref());

        // Test with nullable values
        let array =
            PrimitiveArray::from_option_iter([Some(-10i16), None, Some(20), Some(-30), None]);
        let zigzag = encode_via_vtable(array.into_array());
        test_filter_conformance(zigzag.as_ref());
    }

    #[test]
    fn test_mask_conformance() {
        use vortex_array::compute::conformance::mask::test_mask_conformance;

        // Test with i32 values
        let zigzag = encode_via_vtable(buffer![-100i32, 200, -300, 400, -500].into_array());
        test_mask_conformance(zigzag.as_ref());

        // Test with i8 values
        let zigzag = encode_via_vtable(buffer![-127i8, 0, 127, -1, 1].into_array());
        test_mask_conformance(zigzag.as_ref());
    }

    #[rstest]
    #[case(buffer![-189i32, -160, 1, 42, -73].into_array())]
    #[case(buffer![1000i64, -2000, 3000, -4000, 5000].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(-10i16), None, Some(20), Some(-30), None]).into_array())]
    #[case(buffer![42i32].into_array())]
    fn test_take_zigzag_conformance(#[case] array: ArrayRef) {
        use vortex_array::compute::conformance::take::test_take_conformance;

        let zigzag = encode_via_vtable(array);
        test_take_conformance(zigzag.as_ref());
    }

    #[rstest]
    // Basic ZigZag arrays
    #[case::zigzag_i8(zigzag_encode(PrimitiveArray::from_iter([-128i8, -1, 0, 1, 127])).unwrap())]
    #[case::zigzag_i16(zigzag_encode(PrimitiveArray::from_iter([-1000i16, -100, 0, 100, 1000])).unwrap())]
    #[case::zigzag_i32(zigzag_encode(PrimitiveArray::from_iter([-100000i32, -1000, 0, 1000, 100000])).unwrap())]
    #[case::zigzag_i64(zigzag_encode(PrimitiveArray::from_iter([-1000000i64, -10000, 0, 10000, 1000000])).unwrap())]
    // Nullable arrays
    #[case::zigzag_nullable_i32(zigzag_encode(PrimitiveArray::from_option_iter([Some(-100i32), None, Some(0), Some(100), None])).unwrap())]
    #[case::zigzag_nullable_i64(zigzag_encode(PrimitiveArray::from_option_iter([Some(-1000i64), None, Some(0), Some(1000), None])).unwrap())]
    // Edge cases
    #[case::zigzag_single(zigzag_encode(PrimitiveArray::from_iter([-42i32])).unwrap())]
    #[case::zigzag_alternating(zigzag_encode(PrimitiveArray::from_iter([-1i32, 1, -2, 2, -3, 3])).unwrap())]
    // Large arrays
    #[case::zigzag_large_i32(zigzag_encode(PrimitiveArray::from_iter(-500..500)).unwrap())]
    #[case::zigzag_large_i64(zigzag_encode(PrimitiveArray::from_iter((-1000..1000).map(|i| i as i64 * 100))).unwrap())]
    fn test_zigzag_consistency(#[case] array: ZigZagArray) {
        test_array_consistency(array.as_ref());
    }

    #[rstest]
    #[case::zigzag_i8_basic(zigzag_encode(PrimitiveArray::from_iter([-10i8, -5, 0, 5, 10])).unwrap())]
    #[case::zigzag_i16_basic(zigzag_encode(PrimitiveArray::from_iter([-100i16, -50, 0, 50, 100])).unwrap())]
    #[case::zigzag_i32_basic(zigzag_encode(PrimitiveArray::from_iter([-1000i32, -500, 0, 500, 1000])).unwrap())]
    #[case::zigzag_i64_basic(zigzag_encode(PrimitiveArray::from_iter([-10000i64, -5000, 0, 5000, 10000])).unwrap())]
    #[case::zigzag_i32_large(zigzag_encode(PrimitiveArray::from_iter((-50..50).map(|i| i * 10))).unwrap())]
    fn test_zigzag_binary_numeric(#[case] array: ZigZagArray) {
        test_binary_numeric_array(array.into_array());
    }
}