// vortex_zigzag/compute.rs

1use vortex_array::compute::{
2    FilterKernel, FilterKernelAdapter, KernelRef, ScalarAtFn, SliceFn, TakeFn, filter, scalar_at,
3    slice, take,
4};
5use vortex_array::variants::PrimitiveArrayTrait;
6use vortex_array::vtable::ComputeVTable;
7use vortex_array::{Array, ArrayComputeImpl, ArrayRef};
8use vortex_dtype::match_each_unsigned_integer_ptype;
9use vortex_error::{VortexResult, vortex_err};
10use vortex_mask::Mask;
11use vortex_scalar::{PrimitiveScalar, Scalar};
12use zigzag::{ZigZag as ExternalZigZag, ZigZag};
13
14use crate::{ZigZagArray, ZigZagEncoding};
15
/// Compute-kernel registration for `ZigZagArray`.
///
/// Only the filter kernel is registered through this table. It wraps
/// `ZigZagEncoding` (whose `FilterKernel` impl lives in this file) in a
/// `FilterKernelAdapter`.
impl ArrayComputeImpl for ZigZagArray {
    /// Filter kernel for ZigZag arrays.
    // NOTE(review): kept as a single expression on purpose. If `.some()`
    // borrows its receiver for `'static`, the adapter temporary must be
    // const-promoted, which a `let` binding would prevent — confirm against
    // `FilterKernelAdapter::some` before restructuring.
    const FILTER: Option<KernelRef> = FilterKernelAdapter(ZigZagEncoding).some();
}
19
20impl ComputeVTable for ZigZagEncoding {
21    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
22        Some(self)
23    }
24
25    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
26        Some(self)
27    }
28
29    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
30        Some(self)
31    }
32}
33
34impl FilterKernel for ZigZagEncoding {
35    fn filter(&self, array: &ZigZagArray, mask: &Mask) -> VortexResult<ArrayRef> {
36        let encoded = filter(array.encoded(), mask)?;
37        Ok(ZigZagArray::try_new(encoded)?.into_array())
38    }
39}
40
41impl ScalarAtFn<&ZigZagArray> for ZigZagEncoding {
42    fn scalar_at(&self, array: &ZigZagArray, index: usize) -> VortexResult<Scalar> {
43        let scalar = scalar_at(array.encoded(), index)?;
44        if scalar.is_null() {
45            return Ok(scalar.reinterpret_cast(array.ptype()));
46        }
47
48        let pscalar = PrimitiveScalar::try_from(&scalar)?;
49        match_each_unsigned_integer_ptype!(pscalar.ptype(), |$P| {
50            Ok(Scalar::primitive(
51                <<$P as ZigZagEncoded>::Int>::decode(pscalar.typed_value::<$P>().ok_or_else(|| {
52                    vortex_err!(
53                        "Cannot decode provided scalar: expected {}, got ptype {}",
54                        std::any::type_name::<$P>(),
55                        pscalar.ptype()
56                    )
57                })?),
58                array.dtype().nullability(),
59            ))
60        })
61    }
62}
63
64impl SliceFn<&ZigZagArray> for ZigZagEncoding {
65    fn slice(&self, array: &ZigZagArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
66        Ok(ZigZagArray::try_new(slice(array.encoded(), start, stop)?)?.into_array())
67    }
68}
69
70impl TakeFn<&ZigZagArray> for ZigZagEncoding {
71    fn take(&self, array: &ZigZagArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
72        let encoded = take(array.encoded(), indices)?;
73        Ok(ZigZagArray::try_new(encoded)?.into_array())
74    }
75}
76
77trait ZigZagEncoded {
78    type Int: ZigZag;
79}
80
81impl ZigZagEncoded for u8 {
82    type Int = i8;
83}
84
85impl ZigZagEncoded for u16 {
86    type Int = i16;
87}
88
89impl ZigZagEncoded for u32 {
90    type Int = i32;
91}
92
93impl ZigZagEncoded for u64 {
94    type Int = i64;
95}
96
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{BooleanBuffer, PrimitiveArray};
    use vortex_array::compute::{
        SearchResult, SearchSortedSide, filter, scalar_at, search_sorted, slice, take,
    };
    use vortex_array::validity::Validity;
    use vortex_array::vtable::EncodingVTable;
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use crate::ZigZagEncoding;

    #[test]
    fn search_sorted_uncompressed() {
        let zigzag = ZigZagEncoding
            .encode(
                &buffer![-189, -160, 1].into_array().to_canonical().unwrap(),
                None,
            )
            .unwrap()
            .unwrap();
        assert_eq!(
            search_sorted(&zigzag, -169, SearchSortedSide::Right).unwrap(),
            SearchResult::NotFound(1)
        );
    }

    #[test]
    fn nullable_scalar_at() {
        let zigzag = ZigZagEncoding
            .encode(
                &PrimitiveArray::new(buffer![-189, -160, 1], Validity::AllValid)
                    .to_canonical()
                    .unwrap(),
                None,
            )
            .unwrap()
            .unwrap();
        assert_eq!(
            scalar_at(&zigzag, 1).unwrap(),
            Scalar::primitive(-160, Nullability::Nullable)
        );
    }

    #[test]
    fn take_zigzag() {
        let zigzag = ZigZagEncoding
            .encode(
                &buffer![-189, -160, 1].into_array().to_canonical().unwrap(),
                None,
            )
            .unwrap()
            .unwrap();

        let indices = buffer![0, 2].into_array();
        let actual = take(&zigzag, &indices).unwrap().to_primitive().unwrap();
        // Compare against literal values: deriving the expectation by
        // ZigZag-encoding and decoding would hide a symmetric codec bug.
        assert_eq!(actual.as_slice::<i32>(), &[-189, 1]);
    }

    #[test]
    fn filter_zigzag() {
        let zigzag = ZigZagEncoding
            .encode(
                &buffer![-189, -160, 1].into_array().to_canonical().unwrap(),
                None,
            )
            .unwrap()
            .unwrap();
        let filter_mask = BooleanBuffer::from(vec![true, false, true]).into();
        let actual = filter(&zigzag, &filter_mask)
            .unwrap()
            .to_primitive()
            .unwrap();
        // Literal expectation, for the same reason as in `take_zigzag`.
        assert_eq!(actual.as_slice::<i32>(), &[-189, 1]);
    }

    #[test]
    fn slice_zigzag() {
        let zigzag = ZigZagEncoding
            .encode(
                &buffer![-189, -160, 1].into_array().to_canonical().unwrap(),
                None,
            )
            .unwrap()
            .unwrap();
        let actual = slice(&zigzag, 1, 3).unwrap().to_primitive().unwrap();
        assert_eq!(actual.as_slice::<i32>(), &[-160, 1]);
    }
}