Skip to main content

vortex_zigzag/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod cast;
5
6use vortex_array::ArrayRef;
7use vortex_array::ArrayView;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::arrays::dict::TakeExecute;
11use vortex_array::arrays::filter::FilterReduce;
12use vortex_array::arrays::scalar_fn::ScalarFnFactoryExt;
13use vortex_array::scalar_fn::EmptyOptions;
14use vortex_array::scalar_fn::fns::mask::Mask as MaskExpr;
15use vortex_array::scalar_fn::fns::mask::MaskReduce;
16use vortex_error::VortexResult;
17use vortex_mask::Mask;
18
19use crate::ZigZag;
20use crate::array::ZigZagArrayExt;
21
22impl FilterReduce for ZigZag {
23    fn filter(array: ArrayView<'_, Self>, mask: &Mask) -> VortexResult<Option<ArrayRef>> {
24        let encoded = array.encoded().filter(mask.clone())?;
25        Ok(Some(ZigZag::try_new(encoded)?.into_array()))
26    }
27}
28
29impl TakeExecute for ZigZag {
30    fn take(
31        array: ArrayView<'_, Self>,
32        indices: &ArrayRef,
33        _ctx: &mut ExecutionCtx,
34    ) -> VortexResult<Option<ArrayRef>> {
35        let encoded = array.encoded().take(indices.clone())?;
36        Ok(Some(ZigZag::try_new(encoded)?.into_array()))
37    }
38}
39
40impl MaskReduce for ZigZag {
41    fn mask(array: ArrayView<'_, Self>, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
42        let masked_encoded = MaskExpr.try_new_array(
43            array.encoded().len(),
44            EmptyOptions,
45            [array.encoded().clone(), mask.clone()],
46        )?;
47        Ok(Some(ZigZag::try_new(masked_encoded)?.into_array()))
48    }
49}
50
51pub(crate) trait ZigZagEncoded {
52    type Int: zigzag::ZigZag;
53}
54
55impl ZigZagEncoded for u8 {
56    type Int = i8;
57}
58
59impl ZigZagEncoded for u16 {
60    type Int = i16;
61}
62
63impl ZigZagEncoded for u32 {
64    type Int = i32;
65}
66
67impl ZigZagEncoded for u64 {
68    type Int = i64;
69}
70
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::ZigZagArray;
    use crate::zigzag_encode;

    /// Reading a single element from an `AllValid` array yields the decoded value as a
    /// nullable scalar.
    // Declared without `pub`, like every other test in this module.
    #[test]
    fn nullable_scalar_at() -> VortexResult<()> {
        let zigzag = zigzag_encode(
            PrimitiveArray::new(buffer![-189, -160, 1], Validity::AllValid).as_view(),
        )?;
        assert_eq!(
            zigzag.execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())?,
            Scalar::primitive(-160, Nullability::Nullable)
        );
        Ok(())
    }

    /// `take` on a ZigZag array matches ZigZag-encoding the already-taken values.
    #[test]
    fn take_zigzag() -> VortexResult<()> {
        let zigzag = zigzag_encode(
            PrimitiveArray::new(buffer![-189, -160, 1], Validity::AllValid).as_view(),
        )?;

        let indices = buffer![0, 2].into_array();
        let actual = zigzag.take(indices)?;
        let expected =
            zigzag_encode(PrimitiveArray::new(buffer![-189, 1], Validity::AllValid).as_view())?
                .into_array();
        assert_arrays_eq!(actual, expected);
        Ok(())
    }

    /// `filter` on a ZigZag array matches ZigZag-encoding the already-filtered values.
    #[test]
    fn filter_zigzag() -> VortexResult<()> {
        let zigzag = zigzag_encode(
            PrimitiveArray::new(buffer![-189, -160, 1], Validity::AllValid).as_view(),
        )?;

        let filter_mask = BitBuffer::from(vec![true, false, true]).into();
        let actual = zigzag.filter(filter_mask)?;
        let expected =
            zigzag_encode(PrimitiveArray::new(buffer![-189, 1], Validity::AllValid).as_view())?
                .into_array();
        assert_arrays_eq!(actual, expected);
        Ok(())
    }

    /// Runs the shared filter conformance suite over several widths and a nullable input.
    #[test]
    fn test_filter_conformance() -> VortexResult<()> {
        use vortex_array::compute::conformance::filter::test_filter_conformance;

        // Test with i32 values
        let zigzag = zigzag_encode(
            PrimitiveArray::new(buffer![-189i32, -160, 1, 42, -73], Validity::AllValid).as_view(),
        )?;
        test_filter_conformance(&zigzag.into_array());

        // Test with i64 values
        let zigzag = zigzag_encode(
            PrimitiveArray::new(
                buffer![1000i64, -2000, 3000, -4000, 5000],
                Validity::AllValid,
            )
            .as_view(),
        )?;
        test_filter_conformance(&zigzag.into_array());

        // Test with nullable values
        let array =
            PrimitiveArray::from_option_iter([Some(-10i16), None, Some(20), Some(-30), None]);
        let zigzag = zigzag_encode(array.as_view())?;
        test_filter_conformance(&zigzag.into_array());
        Ok(())
    }

    /// Runs the shared mask conformance suite, including the i8 extremes -127 and 127.
    #[test]
    fn test_mask_conformance() -> VortexResult<()> {
        use vortex_array::compute::conformance::mask::test_mask_conformance;

        // Test with i32 values
        let zigzag = zigzag_encode(
            PrimitiveArray::new(buffer![-100i32, 200, -300, 400, -500], Validity::AllValid)
                .as_view(),
        )?;
        test_mask_conformance(&zigzag.into_array());

        // Test with i8 values
        let zigzag = zigzag_encode(
            PrimitiveArray::new(buffer![-127i8, 0, 127, -1, 1], Validity::AllValid).as_view(),
        )?;
        test_mask_conformance(&zigzag.into_array());
        Ok(())
    }

    /// Runs the shared take conformance suite. Each case is canonicalized to a
    /// `PrimitiveArray` before being ZigZag-encoded.
    #[rstest]
    #[case(buffer![-189i32, -160, 1, 42, -73].into_array())]
    #[case(buffer![1000i64, -2000, 3000, -4000, 5000].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(-10i16), None, Some(20), Some(-30), None]).into_array())]
    #[case(buffer![42i32].into_array())]
    fn test_take_zigzag_conformance(#[case] array: ArrayRef) -> VortexResult<()> {
        use vortex_array::compute::conformance::take::test_take_conformance;

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array_primitive = array.execute::<PrimitiveArray>(&mut ctx)?;
        let zigzag = zigzag_encode(array_primitive.as_view())?;
        test_take_conformance(&zigzag.into_array());
        Ok(())
    }

    /// Runs the shared cross-operation consistency checks over a range of ZigZag arrays.
    #[rstest]
    // Basic ZigZag arrays
    #[case::zigzag_i8(zigzag_encode(PrimitiveArray::from_iter([-128i8, -1, 0, 1, 127]).as_view()).unwrap())]
    #[case::zigzag_i16(zigzag_encode(PrimitiveArray::from_iter([-1000i16, -100, 0, 100, 1000]).as_view()).unwrap())]
    #[case::zigzag_i32(zigzag_encode(PrimitiveArray::from_iter([-100000i32, -1000, 0, 1000, 100000]).as_view()).unwrap())]
    #[case::zigzag_i64(zigzag_encode(PrimitiveArray::from_iter([-1000000i64, -10000, 0, 10000, 1000000]).as_view()).unwrap())]
    // Nullable arrays
    #[case::zigzag_nullable_i32(zigzag_encode(PrimitiveArray::from_option_iter([Some(-100i32), None, Some(0), Some(100), None]).as_view()).unwrap())]
    #[case::zigzag_nullable_i64(zigzag_encode(PrimitiveArray::from_option_iter([Some(-1000i64), None, Some(0), Some(1000), None]).as_view()).unwrap())]
    // Edge cases
    #[case::zigzag_single(zigzag_encode(PrimitiveArray::from_iter([-42i32]).as_view()).unwrap())]
    #[case::zigzag_alternating(zigzag_encode(PrimitiveArray::from_iter([-1i32, 1, -2, 2, -3, 3]).as_view()).unwrap())]
    // Large arrays
    #[case::zigzag_large_i32(zigzag_encode(PrimitiveArray::from_iter(-500..500).as_view()).unwrap())]
    #[case::zigzag_large_i64(zigzag_encode(PrimitiveArray::from_iter((-1000..1000).map(|i| i as i64 * 100)).as_view()).unwrap())]
    fn test_zigzag_consistency(#[case] array: ZigZagArray) {
        test_array_consistency(&array.into_array());
    }

    /// Runs the shared binary-numeric conformance suite over ZigZag arrays of each width.
    #[rstest]
    #[case::zigzag_i8_basic(zigzag_encode(PrimitiveArray::from_iter([-10i8, -5, 0, 5, 10]).as_view()).unwrap())]
    #[case::zigzag_i16_basic(zigzag_encode(PrimitiveArray::from_iter([-100i16, -50, 0, 50, 100]).as_view()).unwrap())]
    #[case::zigzag_i32_basic(zigzag_encode(PrimitiveArray::from_iter([-1000i32, -500, 0, 500, 1000]).as_view()).unwrap())]
    #[case::zigzag_i64_basic(zigzag_encode(PrimitiveArray::from_iter([-10000i64, -5000, 0, 5000, 10000]).as_view()).unwrap())]
    #[case::zigzag_i32_large(zigzag_encode(PrimitiveArray::from_iter((-50..50).map(|i| i * 10)).as_view()).unwrap())]
    fn test_zigzag_binary_numeric(#[case] array: ZigZagArray) {
        test_binary_numeric_array(array.into_array());
    }
}