Skip to main content

vortex_zigzag/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod cast;
5
6use vortex_array::ArrayRef;
7use vortex_array::ArrayView;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::arrays::dict::TakeExecute;
11use vortex_array::arrays::filter::FilterReduce;
12use vortex_array::arrays::scalar_fn::ScalarFnFactoryExt;
13use vortex_array::scalar_fn::EmptyOptions;
14use vortex_array::scalar_fn::fns::mask::Mask as MaskExpr;
15use vortex_array::scalar_fn::fns::mask::MaskReduce;
16use vortex_error::VortexResult;
17use vortex_mask::Mask;
18
19use crate::ZigZag;
20use crate::array::ZigZagArrayExt;
21
22impl FilterReduce for ZigZag {
23    fn filter(array: ArrayView<'_, Self>, mask: &Mask) -> VortexResult<Option<ArrayRef>> {
24        let encoded = array.encoded().filter(mask.clone())?;
25        Ok(Some(ZigZag::try_new(encoded)?.into_array()))
26    }
27}
28
29impl TakeExecute for ZigZag {
30    fn take(
31        array: ArrayView<'_, Self>,
32        indices: &ArrayRef,
33        _ctx: &mut ExecutionCtx,
34    ) -> VortexResult<Option<ArrayRef>> {
35        let encoded = array.encoded().take(indices.clone())?;
36        Ok(Some(ZigZag::try_new(encoded)?.into_array()))
37    }
38}
39
40impl MaskReduce for ZigZag {
41    fn mask(array: ArrayView<'_, Self>, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
42        let masked_encoded = MaskExpr.try_new_array(
43            array.encoded().len(),
44            EmptyOptions,
45            [array.encoded().clone(), mask.clone()],
46        )?;
47        Ok(Some(ZigZag::try_new(masked_encoded)?.into_array()))
48    }
49}
50
/// Pairs an implementing type with a companion integer type.
///
/// In this file the trait is implemented for `u8`, `u16`, `u32` and `u64`, each paired
/// with the signed integer of the same width.
pub(crate) trait ZigZagEncoded {
    /// The companion integer type; it must implement `zigzag::ZigZag`.
    type Int: zigzag::ZigZag;
}
54
/// `u8` is paired with `i8`.
impl ZigZagEncoded for u8 {
    type Int = i8;
}
58
/// `u16` is paired with `i16`.
impl ZigZagEncoded for u16 {
    type Int = i16;
}
62
/// `u32` is paired with `i32`.
impl ZigZagEncoded for u32 {
    type Int = i32;
}
66
/// `u64` is paired with `i64`.
impl ZigZagEncoded for u64 {
    type Int = i64;
}
70
71#[cfg(test)]
72mod tests {
73    use rstest::rstest;
74    use vortex_array::ArrayRef;
75    use vortex_array::IntoArray;
76    use vortex_array::LEGACY_SESSION;
77    use vortex_array::ToCanonical;
78    use vortex_array::VortexSessionExecute;
79    use vortex_array::arrays::PrimitiveArray;
80    use vortex_array::assert_arrays_eq;
81    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
82    use vortex_array::compute::conformance::consistency::test_array_consistency;
83    use vortex_array::dtype::Nullability;
84    use vortex_array::scalar::Scalar;
85    use vortex_array::validity::Validity;
86    use vortex_buffer::BitBuffer;
87    use vortex_buffer::buffer;
88    use vortex_error::VortexResult;
89
90    use crate::ZigZagArray;
91    use crate::zigzag_encode;
92
93    #[test]
94    pub fn nullable_scalar_at() -> VortexResult<()> {
95        let zigzag = zigzag_encode(
96            PrimitiveArray::new(buffer![-189, -160, 1], Validity::AllValid).as_view(),
97        )?;
98        assert_eq!(
99            zigzag.execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())?,
100            Scalar::primitive(-160, Nullability::Nullable)
101        );
102        Ok(())
103    }
104
105    #[test]
106    fn take_zigzag() -> VortexResult<()> {
107        let zigzag = zigzag_encode(
108            PrimitiveArray::new(buffer![-189, -160, 1], Validity::AllValid).as_view(),
109        )?;
110
111        let indices = buffer![0, 2].into_array();
112        let actual = zigzag.take(indices).unwrap();
113        let expected =
114            zigzag_encode(PrimitiveArray::new(buffer![-189, 1], Validity::AllValid).as_view())?
115                .into_array();
116        assert_arrays_eq!(actual, expected);
117        Ok(())
118    }
119
120    #[test]
121    fn filter_zigzag() -> VortexResult<()> {
122        let zigzag = zigzag_encode(
123            PrimitiveArray::new(buffer![-189, -160, 1], Validity::AllValid).as_view(),
124        )?;
125
126        let filter_mask = BitBuffer::from(vec![true, false, true]).into();
127        let actual = zigzag.filter(filter_mask).unwrap();
128        let expected =
129            zigzag_encode(PrimitiveArray::new(buffer![-189, 1], Validity::AllValid).as_view())?
130                .into_array();
131        assert_arrays_eq!(actual, expected);
132        Ok(())
133    }
134
135    #[test]
136    fn test_filter_conformance() -> VortexResult<()> {
137        use vortex_array::compute::conformance::filter::test_filter_conformance;
138
139        // Test with i32 values
140        let zigzag = zigzag_encode(
141            PrimitiveArray::new(buffer![-189i32, -160, 1, 42, -73], Validity::AllValid).as_view(),
142        )?;
143        test_filter_conformance(&zigzag.into_array());
144
145        // Test with i64 values
146        let zigzag = zigzag_encode(
147            PrimitiveArray::new(
148                buffer![1000i64, -2000, 3000, -4000, 5000],
149                Validity::AllValid,
150            )
151            .as_view(),
152        )?;
153        test_filter_conformance(&zigzag.into_array());
154
155        // Test with nullable values
156        let array =
157            PrimitiveArray::from_option_iter([Some(-10i16), None, Some(20), Some(-30), None]);
158        let zigzag = zigzag_encode(array.as_view())?;
159        test_filter_conformance(&zigzag.into_array());
160        Ok(())
161    }
162
163    #[test]
164    fn test_mask_conformance() -> VortexResult<()> {
165        use vortex_array::compute::conformance::mask::test_mask_conformance;
166
167        // Test with i32 values
168        let zigzag = zigzag_encode(
169            PrimitiveArray::new(buffer![-100i32, 200, -300, 400, -500], Validity::AllValid)
170                .as_view(),
171        )?;
172        test_mask_conformance(&zigzag.into_array());
173
174        // Test with i8 values
175        let zigzag = zigzag_encode(
176            PrimitiveArray::new(buffer![-127i8, 0, 127, -1, 1], Validity::AllValid).as_view(),
177        )?;
178        test_mask_conformance(&zigzag.into_array());
179        Ok(())
180    }
181
182    #[rstest]
183    #[case(buffer![-189i32, -160, 1, 42, -73].into_array())]
184    #[case(buffer![1000i64, -2000, 3000, -4000, 5000].into_array())]
185    #[case(PrimitiveArray::from_option_iter([Some(-10i16), None, Some(20), Some(-30), None]).into_array()
186    )]
187    #[case(buffer![42i32].into_array())]
188    fn test_take_zigzag_conformance(#[case] array: ArrayRef) -> VortexResult<()> {
189        use vortex_array::compute::conformance::take::test_take_conformance;
190
191        let zigzag = zigzag_encode(array.to_primitive().as_view())?;
192        test_take_conformance(&zigzag.into_array());
193        Ok(())
194    }
195
    // Runs the shared array-consistency helper over ZigZag arrays built from a range of
    // signed inputs: each integer width, nullable inputs, a single element, alternating
    // signs, and two larger ranges. Each case panics at construction if encoding fails.
    #[rstest]
    // Basic ZigZag arrays
    #[case::zigzag_i8(zigzag_encode(PrimitiveArray::from_iter([-128i8, -1, 0, 1, 127]).as_view()).unwrap())]
    #[case::zigzag_i16(zigzag_encode(PrimitiveArray::from_iter([-1000i16, -100, 0, 100, 1000]).as_view()).unwrap())]
    #[case::zigzag_i32(zigzag_encode(PrimitiveArray::from_iter([-100000i32, -1000, 0, 1000, 100000]).as_view()).unwrap())]
    #[case::zigzag_i64(zigzag_encode(PrimitiveArray::from_iter([-1000000i64, -10000, 0, 10000, 1000000]).as_view()).unwrap())]
    // Nullable arrays
    #[case::zigzag_nullable_i32(zigzag_encode(PrimitiveArray::from_option_iter([Some(-100i32), None, Some(0), Some(100), None]).as_view()).unwrap())]
    #[case::zigzag_nullable_i64(zigzag_encode(PrimitiveArray::from_option_iter([Some(-1000i64), None, Some(0), Some(1000), None]).as_view()).unwrap())]
    // Edge cases
    #[case::zigzag_single(zigzag_encode(PrimitiveArray::from_iter([-42i32]).as_view()).unwrap())]
    #[case::zigzag_alternating(zigzag_encode(PrimitiveArray::from_iter([-1i32, 1, -2, 2, -3, 3]).as_view()).unwrap())]
    // Large arrays
    #[case::zigzag_large_i32(zigzag_encode(PrimitiveArray::from_iter(-500..500).as_view()).unwrap())]
    #[case::zigzag_large_i64(zigzag_encode(PrimitiveArray::from_iter((-1000..1000).map(|i| i as i64 * 100)).as_view()).unwrap())]
    fn test_zigzag_consistency(#[case] array: ZigZagArray) {
        // The helper takes a reference to the type-erased array.
        test_array_consistency(&array.into_array());
    }
214
    // Runs the shared binary-numeric conformance helper over ZigZag arrays of each signed
    // integer width, plus one longer i32 array of multiples of ten. Each case panics at
    // construction if encoding fails.
    #[rstest]
    #[case::zigzag_i8_basic(zigzag_encode(PrimitiveArray::from_iter([-10i8, -5, 0, 5, 10]).as_view()).unwrap())]
    #[case::zigzag_i16_basic(zigzag_encode(PrimitiveArray::from_iter([-100i16, -50, 0, 50, 100]).as_view()).unwrap())]
    #[case::zigzag_i32_basic(zigzag_encode(PrimitiveArray::from_iter([-1000i32, -500, 0, 500, 1000]).as_view()).unwrap())]
    #[case::zigzag_i64_basic(zigzag_encode(PrimitiveArray::from_iter([-10000i64, -5000, 0, 5000, 10000]).as_view()).unwrap())]
    #[case::zigzag_i32_large(zigzag_encode(PrimitiveArray::from_iter((-50..50).map(|i| i * 10)).as_view()).unwrap())]
    fn test_zigzag_binary_numeric(#[case] array: ZigZagArray) {
        // Unlike the consistency helper, this one takes the type-erased array by value.
        test_binary_numeric_array(array.into_array());
    }
224}