Skip to main content

vortex_sequence/
compress.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Add;
5
6use num_traits::CheckedAdd;
7use num_traits::CheckedSub;
8use vortex_array::ArrayRef;
9use vortex_array::IntoArray;
10use vortex_array::arrays::PrimitiveArray;
11use vortex_array::dtype::NativePType;
12use vortex_array::dtype::Nullability;
13use vortex_array::match_each_integer_ptype;
14use vortex_array::match_each_native_ptype;
15use vortex_array::scalar::PValue;
16use vortex_array::validity::Validity;
17use vortex_buffer::BufferMut;
18use vortex_buffer::trusted_len::TrustedLen;
19use vortex_error::VortexResult;
20
21use crate::SequenceArray;
22
/// Iterator over the arithmetic progression `base, base + step, base + 2 * step, ...`.
///
/// Every element is derived from its predecessor with a single addition, and exactly
/// `remaining` elements are produced in total.
struct SequenceIter<T> {
    /// The value the next call to [`Iterator::next`] will yield.
    acc: T,
    /// The increment applied between two consecutive elements.
    step: T,
    /// How many elements are still to be yielded.
    remaining: usize,
}

impl<T> Iterator for SequenceIter<T>
where
    T: Copy + Add<Output = T>,
{
    type Item = T;

    #[inline]
    fn next(&mut self) -> Option<T> {
        // `checked_sub` yields `None` exactly when the iterator is exhausted.
        let left = self.remaining.checked_sub(1)?;
        self.remaining = left;

        let current = self.acc;
        // Do not advance past the final element: the value after it is never yielded and
        // computing it could overflow `T` even though every yielded value is in range.
        if left != 0 {
            self.acc = current + self.step;
        }
        Some(current)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let exact = self.remaining;
        (exact, Some(exact))
    }
}
51
52// SAFETY: `size_hint` returns an exact count and `next` yields exactly that many items.
53unsafe impl<T: Copy + Add<Output = T>> TrustedLen for SequenceIter<T> {}
54
55/// Decompresses a [`SequenceArray`] into a [`PrimitiveArray`].
56#[inline]
57pub fn sequence_decompress(array: &SequenceArray) -> VortexResult<ArrayRef> {
58    fn decompress_inner<P: NativePType>(
59        base: P,
60        multiplier: P,
61        len: usize,
62        nullability: Nullability,
63    ) -> PrimitiveArray {
64        let values = BufferMut::from_trusted_len_iter(SequenceIter {
65            acc: base,
66            step: multiplier,
67            remaining: len,
68        });
69        PrimitiveArray::new(values, Validity::from(nullability))
70    }
71
72    let prim = match_each_native_ptype!(array.ptype(), |P| {
73        let base = array.base().cast::<P>()?;
74        let multiplier = array.multiplier().cast::<P>()?;
75        decompress_inner(base, multiplier, array.len(), array.dtype().nullability())
76    });
77    Ok(prim.into_array())
78}
79
80/// Encodes a primitive array into a sequence array, this is possible if:
81/// 1. The array is not empty, and contains no nulls
82/// 2. The array is not a float array. (This is due to precision issues, how it will stack well with ALP).
83/// 3. The array is representable as a sequence `A[i] = base + i * multiplier` for multiplier != 0.
84/// 4. The sequence has no deviations from the equation, this could be fixed with patches. However,
85///    we might want a different array for that since sequence provide fast access.
86pub fn sequence_encode(primitive_array: &PrimitiveArray) -> VortexResult<Option<ArrayRef>> {
87    if primitive_array.is_empty() {
88        // we cannot encode an empty array
89        return Ok(None);
90    }
91
92    if !primitive_array.all_valid()? {
93        return Ok(None);
94    }
95
96    if primitive_array.ptype().is_float() {
97        // for now, we don't handle float arrays, due to possible precision issues
98        return Ok(None);
99    }
100
101    match_each_integer_ptype!(primitive_array.ptype(), |P| {
102        encode_primitive_array(
103            primitive_array.as_slice::<P>(),
104            primitive_array.dtype().nullability(),
105        )
106    })
107}
108
109fn encode_primitive_array<P: NativePType + Into<PValue> + CheckedAdd + CheckedSub>(
110    slice: &[P],
111    nullability: Nullability,
112) -> VortexResult<Option<ArrayRef>> {
113    if slice.len() == 1 {
114        // The multiplier here can be any value, zero is chosen
115        return SequenceArray::try_new_typed(slice[0], P::zero(), nullability, 1)
116            .map(|a| Some(a.into_array()));
117    }
118    let base = slice[0];
119    let Some(multiplier) = slice[1].checked_sub(&base) else {
120        return Ok(None);
121    };
122
123    if multiplier == P::zero() {
124        return Ok(None);
125    }
126
127    if SequenceArray::try_last(base.into(), multiplier.into(), P::PTYPE, slice.len()).is_err() {
128        // If the last value is out of range, we cannot encode
129        return Ok(None);
130    }
131
132    slice
133        .windows(2)
134        .all(|w| Some(w[1]) == w[0].checked_add(&multiplier))
135        .then_some(
136            SequenceArray::try_new_typed(base, multiplier, nullability, slice.len())
137                .map(|a| a.into_array()),
138        )
139        .transpose()
140}
141
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use itertools::Itertools;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;

    use crate::sequence_encode;

    /// Encodes `primitive_array`, asserts that encoding succeeded, and checks that decoding
    /// reproduces the input.
    fn assert_roundtrip(primitive_array: PrimitiveArray) {
        let encoded = sequence_encode(&primitive_array).unwrap();
        assert!(encoded.is_some());
        let decoded = encoded.unwrap().to_primitive();
        assert_arrays_eq!(decoded, primitive_array);
    }

    /// Encodes `primitive_array` and asserts that it was rejected.
    fn assert_not_encodable(primitive_array: PrimitiveArray) {
        let encoded = sequence_encode(&primitive_array).unwrap();
        assert!(encoded.is_none());
    }

    #[test]
    fn test_encode_array_success() {
        assert_roundtrip(PrimitiveArray::from_iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]));
    }

    #[test]
    fn test_encode_array_1_success() {
        assert_roundtrip(PrimitiveArray::from_iter([0]));
    }

    #[test]
    fn test_encode_array_fail() {
        // The trailing `0` breaks the step of 1.
        assert_not_encodable(PrimitiveArray::from_iter([
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0,
        ]));
    }

    #[test]
    fn test_encode_array_fail_oob() {
        // NOTE: a constant array has a step of zero, so it is rejected by the zero-multiplier
        // guard and never reaches the out-of-range check. See
        // `test_encode_array_fail_wrapping` for a sequence whose step would leave the type's
        // range.
        assert_not_encodable(PrimitiveArray::from_iter(vec![100i8; 1000]));
    }

    #[test]
    fn test_encode_array_fail_wrapping() {
        // Step is 1, but `126 + 2 * 1` does not fit in `i8`: the data only "continues" the
        // sequence by wrapping around, which must not be encoded.
        assert_not_encodable(PrimitiveArray::from_iter([126i8, 127, -128]));
    }

    #[test]
    fn test_encode_all_u8_values() {
        assert_roundtrip(PrimitiveArray::from_iter(0u8..=255));
    }
}