Skip to main content

vortex_zigzag/compute/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod cast;
5
6use vortex_array::Array;
7use vortex_array::ArrayRef;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::arrays::FilterReduce;
11use vortex_array::arrays::ScalarFnArrayExt;
12use vortex_array::arrays::TakeExecute;
13use vortex_array::scalar_fn::EmptyOptions;
14use vortex_array::scalar_fn::fns::mask::Mask as MaskExpr;
15use vortex_array::scalar_fn::fns::mask::MaskReduce;
16use vortex_error::VortexResult;
17use vortex_mask::Mask;
18
19use crate::ZigZagArray;
20use crate::ZigZagVTable;
21
22impl FilterReduce for ZigZagVTable {
23    fn filter(array: &ZigZagArray, mask: &Mask) -> VortexResult<Option<ArrayRef>> {
24        let encoded = array.encoded().filter(mask.clone())?;
25        Ok(Some(ZigZagArray::try_new(encoded)?.into_array()))
26    }
27}
28
29impl TakeExecute for ZigZagVTable {
30    fn take(
31        array: &ZigZagArray,
32        indices: &dyn Array,
33        _ctx: &mut ExecutionCtx,
34    ) -> VortexResult<Option<ArrayRef>> {
35        let encoded = array.encoded().take(indices.to_array())?;
36        Ok(Some(ZigZagArray::try_new(encoded)?.into_array()))
37    }
38}
39
40impl MaskReduce for ZigZagVTable {
41    fn mask(array: &ZigZagArray, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
42        let masked_encoded = MaskExpr.try_new_array(
43            array.encoded().len(),
44            EmptyOptions,
45            [array.encoded().clone(), mask.clone()],
46        )?;
47        Ok(Some(ZigZagArray::try_new(masked_encoded)?.into_array()))
48    }
49}
50
51pub(crate) trait ZigZagEncoded {
52    type Int: zigzag::ZigZag;
53}
54
55impl ZigZagEncoded for u8 {
56    type Int = i8;
57}
58
59impl ZigZagEncoded for u16 {
60    type Int = i16;
61}
62
63impl ZigZagEncoded for u32 {
64    type Int = i32;
65}
66
67impl ZigZagEncoded for u64 {
68    type Int = i64;
69}
70
#[cfg(test)]
mod tests {
    // Tests for the ZigZag compute kernels (scalar access, take, filter, mask)
    // plus the shared vortex-array conformance suites.
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::ZigZagArray;
    use crate::zigzag_encode;

    /// A scalar read from an `AllValid` array comes back as a nullable scalar
    /// holding the original (decoded) value.
    #[test]
    fn nullable_scalar_at() -> VortexResult<()> {
        let zigzag = zigzag_encode(PrimitiveArray::new(
            buffer![-189, -160, 1],
            Validity::AllValid,
        ))?;
        assert_eq!(
            zigzag.scalar_at(1)?,
            Scalar::primitive(-160, Nullability::Nullable)
        );
        Ok(())
    }

    /// Taking rows from a ZigZag array matches encoding the taken values directly.
    #[test]
    fn take_zigzag() -> VortexResult<()> {
        let zigzag = zigzag_encode(PrimitiveArray::new(
            buffer![-189, -160, 1],
            Validity::AllValid,
        ))?;

        let indices = buffer![0, 2].into_array();
        // Propagate take errors through the test's Result instead of panicking.
        let actual = zigzag.take(indices)?;
        let expected =
            zigzag_encode(PrimitiveArray::new(buffer![-189, 1], Validity::AllValid))?.into_array();
        assert_arrays_eq!(actual, expected);
        Ok(())
    }

    /// Filtering a ZigZag array matches encoding the filtered values directly.
    #[test]
    fn filter_zigzag() -> VortexResult<()> {
        let zigzag = zigzag_encode(PrimitiveArray::new(
            buffer![-189, -160, 1],
            Validity::AllValid,
        ))?;

        let filter_mask = BitBuffer::from(vec![true, false, true]).into();
        // Propagate filter errors through the test's Result instead of panicking.
        let actual = zigzag.filter(filter_mask)?;
        let expected =
            zigzag_encode(PrimitiveArray::new(buffer![-189, 1], Validity::AllValid))?.into_array();
        assert_arrays_eq!(actual, expected);
        Ok(())
    }

    #[test]
    fn test_filter_conformance() -> VortexResult<()> {
        use vortex_array::compute::conformance::filter::test_filter_conformance;

        // Test with i32 values
        let zigzag = zigzag_encode(PrimitiveArray::new(
            buffer![-189i32, -160, 1, 42, -73],
            Validity::AllValid,
        ))?;
        test_filter_conformance(zigzag.as_ref());

        // Test with i64 values
        let zigzag = zigzag_encode(PrimitiveArray::new(
            buffer![1000i64, -2000, 3000, -4000, 5000],
            Validity::AllValid,
        ))?;
        test_filter_conformance(zigzag.as_ref());

        // Test with nullable values
        let array =
            PrimitiveArray::from_option_iter([Some(-10i16), None, Some(20), Some(-30), None]);
        let zigzag = zigzag_encode(array)?;
        test_filter_conformance(zigzag.as_ref());
        Ok(())
    }

    #[test]
    fn test_mask_conformance() -> VortexResult<()> {
        use vortex_array::compute::conformance::mask::test_mask_conformance;

        // Test with i32 values
        let zigzag = zigzag_encode(PrimitiveArray::new(
            buffer![-100i32, 200, -300, 400, -500],
            Validity::AllValid,
        ))?;
        test_mask_conformance(zigzag.as_ref());

        // Test with i8 values
        let zigzag = zigzag_encode(PrimitiveArray::new(
            buffer![-127i8, 0, 127, -1, 1],
            Validity::AllValid,
        ))?;
        test_mask_conformance(zigzag.as_ref());
        Ok(())
    }

    #[rstest]
    #[case(buffer![-189i32, -160, 1, 42, -73].into_array())]
    #[case(buffer![1000i64, -2000, 3000, -4000, 5000].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(-10i16), None, Some(20), Some(-30), None]).into_array()
    )]
    #[case(buffer![42i32].into_array())]
    fn test_take_zigzag_conformance(#[case] array: ArrayRef) -> VortexResult<()> {
        use vortex_array::compute::conformance::take::test_take_conformance;

        let zigzag = zigzag_encode(array.to_primitive())?;
        test_take_conformance(zigzag.as_ref());
        Ok(())
    }

    #[rstest]
    // Basic ZigZag arrays
    #[case::zigzag_i8(zigzag_encode(PrimitiveArray::from_iter([-128i8, -1, 0, 1, 127])).unwrap())]
    #[case::zigzag_i16(zigzag_encode(PrimitiveArray::from_iter([-1000i16, -100, 0, 100, 1000])).unwrap()
    )]
    #[case::zigzag_i32(zigzag_encode(PrimitiveArray::from_iter([-100000i32, -1000, 0, 1000, 100000])).unwrap()
    )]
    #[case::zigzag_i64(zigzag_encode(PrimitiveArray::from_iter([-1000000i64, -10000, 0, 10000, 1000000])).unwrap()
    )]
    // Nullable arrays
    #[case::zigzag_nullable_i32(zigzag_encode(PrimitiveArray::from_option_iter([Some(-100i32), None, Some(0), Some(100), None])).unwrap()
    )]
    #[case::zigzag_nullable_i64(zigzag_encode(PrimitiveArray::from_option_iter([Some(-1000i64), None, Some(0), Some(1000), None])).unwrap()
    )]
    // Edge cases
    #[case::zigzag_single(zigzag_encode(PrimitiveArray::from_iter([-42i32])).unwrap())]
    #[case::zigzag_alternating(zigzag_encode(PrimitiveArray::from_iter([-1i32, 1, -2, 2, -3, 3])).unwrap()
    )]
    // Large arrays
    #[case::zigzag_large_i32(zigzag_encode(PrimitiveArray::from_iter(-500..500)).unwrap())]
    #[case::zigzag_large_i64(zigzag_encode(PrimitiveArray::from_iter((-1000..1000).map(|i| i as i64 * 100))).unwrap()
    )]
    fn test_zigzag_consistency(#[case] array: ZigZagArray) {
        test_array_consistency(array.as_ref());
    }

    #[rstest]
    #[case::zigzag_i8_basic(zigzag_encode(PrimitiveArray::from_iter([-10i8, -5, 0, 5, 10])).unwrap()
    )]
    #[case::zigzag_i16_basic(zigzag_encode(PrimitiveArray::from_iter([-100i16, -50, 0, 50, 100])).unwrap()
    )]
    #[case::zigzag_i32_basic(zigzag_encode(PrimitiveArray::from_iter([-1000i32, -500, 0, 500, 1000])).unwrap()
    )]
    #[case::zigzag_i64_basic(zigzag_encode(PrimitiveArray::from_iter([-10000i64, -5000, 0, 5000, 10000])).unwrap()
    )]
    #[case::zigzag_i32_large(zigzag_encode(PrimitiveArray::from_iter((-50..50).map(|i| i * 10))).unwrap()
    )]
    fn test_zigzag_binary_numeric(#[case] array: ZigZagArray) {
        test_binary_numeric_array(array.into_array());
    }
}