vortex_zigzag/
compute.rs

1use vortex_array::compute::{
2    FilterKernel, FilterKernelAdapter, ScalarAtFn, SliceFn, TakeFn, filter, scalar_at, slice, take,
3};
4use vortex_array::variants::PrimitiveArrayTrait;
5use vortex_array::vtable::ComputeVTable;
6use vortex_array::{Array, ArrayRef, register_kernel};
7use vortex_dtype::match_each_unsigned_integer_ptype;
8use vortex_error::{VortexResult, vortex_err};
9use vortex_mask::Mask;
10use vortex_scalar::{PrimitiveScalar, Scalar};
11use zigzag::{ZigZag as ExternalZigZag, ZigZag};
12
13use crate::{ZigZagArray, ZigZagEncoding};
14
15impl ComputeVTable for ZigZagEncoding {
16    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
17        Some(self)
18    }
19
20    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
21        Some(self)
22    }
23
24    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
25        Some(self)
26    }
27}
28
29impl FilterKernel for ZigZagEncoding {
30    fn filter(&self, array: &ZigZagArray, mask: &Mask) -> VortexResult<ArrayRef> {
31        let encoded = filter(array.encoded(), mask)?;
32        Ok(ZigZagArray::try_new(encoded)?.into_array())
33    }
34}
35
36register_kernel!(FilterKernelAdapter(ZigZagEncoding).lift());
37
38impl ScalarAtFn<&ZigZagArray> for ZigZagEncoding {
39    fn scalar_at(&self, array: &ZigZagArray, index: usize) -> VortexResult<Scalar> {
40        let scalar = scalar_at(array.encoded(), index)?;
41        if scalar.is_null() {
42            return Ok(scalar.reinterpret_cast(array.ptype()));
43        }
44
45        let pscalar = PrimitiveScalar::try_from(&scalar)?;
46        match_each_unsigned_integer_ptype!(pscalar.ptype(), |$P| {
47            Ok(Scalar::primitive(
48                <<$P as ZigZagEncoded>::Int>::decode(pscalar.typed_value::<$P>().ok_or_else(|| {
49                    vortex_err!(
50                        "Cannot decode provided scalar: expected {}, got ptype {}",
51                        std::any::type_name::<$P>(),
52                        pscalar.ptype()
53                    )
54                })?),
55                array.dtype().nullability(),
56            ))
57        })
58    }
59}
60
61impl SliceFn<&ZigZagArray> for ZigZagEncoding {
62    fn slice(&self, array: &ZigZagArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
63        Ok(ZigZagArray::try_new(slice(array.encoded(), start, stop)?)?.into_array())
64    }
65}
66
67impl TakeFn<&ZigZagArray> for ZigZagEncoding {
68    fn take(&self, array: &ZigZagArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
69        let encoded = take(array.encoded(), indices)?;
70        Ok(ZigZagArray::try_new(encoded)?.into_array())
71    }
72}
73
74trait ZigZagEncoded {
75    type Int: ZigZag;
76}
77
78impl ZigZagEncoded for u8 {
79    type Int = i8;
80}
81
82impl ZigZagEncoded for u16 {
83    type Int = i16;
84}
85
86impl ZigZagEncoded for u32 {
87    type Int = i32;
88}
89
90impl ZigZagEncoded for u64 {
91    type Int = i64;
92}
93
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{BooleanBuffer, PrimitiveArray};
    use vortex_array::compute::{
        SearchResult, SearchSortedSide, filter, scalar_at, search_sorted, slice, take,
    };
    use vortex_array::validity::Validity;
    use vortex_array::vtable::EncodingVTable;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use crate::ZigZagEncoding;

    /// ZigZag-encodes `array`, panicking if encoding fails or is declined.
    fn encode_zigzag(array: ArrayRef) -> ArrayRef {
        ZigZagEncoding
            .encode(&array.to_canonical().unwrap(), None)
            .unwrap()
            .expect("ZigZag encoding should accept a signed primitive array")
    }

    #[test]
    fn search_sorted_uncompressed() {
        let zigzag = encode_zigzag(buffer![-189, -160, 1].into_array());
        assert_eq!(
            search_sorted(&zigzag, -169, SearchSortedSide::Right).unwrap(),
            SearchResult::NotFound(1)
        );
    }

    #[test]
    fn nullable_scalar_at() {
        let zigzag = encode_zigzag(
            PrimitiveArray::new(buffer![-189, -160, 1], Validity::AllValid).into_array(),
        );
        assert_eq!(
            scalar_at(&zigzag, 1).unwrap(),
            Scalar::primitive(-160, Nullability::Nullable)
        );
    }

    #[test]
    fn null_scalar_at() {
        let zigzag = encode_zigzag(
            PrimitiveArray::from_option_iter([Some(-189i32), None, Some(1)]).into_array(),
        );
        let scalar = scalar_at(&zigzag, 1).unwrap();
        assert!(scalar.is_null());
        // A null must be reported with the decoded (signed) dtype, not the unsigned storage type.
        assert_eq!(scalar.dtype(), zigzag.dtype());
    }

    #[test]
    fn slice_zigzag() {
        let zigzag = encode_zigzag(buffer![-189, -160, 1].into_array());
        let actual = slice(&zigzag, 1, 3).unwrap().to_primitive().unwrap();
        assert_eq!(actual.as_slice::<i32>(), &[-160, 1]);
    }

    #[test]
    fn take_zigzag() {
        let zigzag = encode_zigzag(buffer![-189, -160, 1].into_array());
        let indices = buffer![0, 2].into_array();
        let actual = take(&zigzag, &indices).unwrap().to_primitive().unwrap();
        // Compare against literals, not another encode/decode round trip, so a
        // symmetric decoding bug cannot cancel itself out.
        assert_eq!(actual.as_slice::<i32>(), &[-189, 1]);
    }

    #[test]
    fn filter_zigzag() {
        let zigzag = encode_zigzag(buffer![-189, -160, 1].into_array());
        let filter_mask = BooleanBuffer::from(vec![true, false, true]).into();
        let actual = filter(&zigzag, &filter_mask)
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(actual.as_slice::<i32>(), &[-189, 1]);
    }
}