vortex_runend/
compress.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use vortex_array::arrays::{BoolArray, ConstantArray, PrimitiveArray};
6use vortex_array::validity::Validity;
7use vortex_array::vtable::ValidityHelper;
8use vortex_array::{ArrayRef, IntoArray, ToCanonical};
9use vortex_buffer::{BitBuffer, BitBufferMut, Buffer, BufferMut, buffer};
10use vortex_dtype::{
11    NativePType, Nullability, match_each_native_ptype, match_each_unsigned_integer_ptype,
12};
13use vortex_error::VortexExpect;
14use vortex_mask::Mask;
15use vortex_scalar::Scalar;
16
17use crate::iter::trimmed_ends_iter;
18
19/// Run-end encode a `PrimitiveArray`, returning a tuple of `(ends, values)`.
20pub fn runend_encode(array: &PrimitiveArray) -> (PrimitiveArray, ArrayRef) {
21    let validity = match array.validity() {
22        Validity::NonNullable => None,
23        Validity::AllValid => None,
24        Validity::AllInvalid => {
25            // We can trivially return an all-null REE array
26            return (
27                PrimitiveArray::new(buffer![array.len() as u64], Validity::NonNullable),
28                ConstantArray::new(Scalar::null(array.dtype().clone()), 1).into_array(),
29            );
30        }
31        Validity::Array(a) => Some(a.to_bool().bit_buffer().clone()),
32    };
33
34    let (ends, values) = match validity {
35        None => {
36            match_each_native_ptype!(array.ptype(), |P| {
37                let (ends, values) = runend_encode_primitive(array.as_slice::<P>());
38                (
39                    PrimitiveArray::new(ends, Validity::NonNullable),
40                    PrimitiveArray::new(values, array.dtype().nullability().into()).into_array(),
41                )
42            })
43        }
44        Some(validity) => {
45            match_each_native_ptype!(array.ptype(), |P| {
46                let (ends, values) =
47                    runend_encode_nullable_primitive(array.as_slice::<P>(), validity);
48                (
49                    PrimitiveArray::new(ends, Validity::NonNullable),
50                    values.into_array(),
51                )
52            })
53        }
54    };
55
56    let ends = ends
57        .narrow()
58        .vortex_expect("Ends must succeed downcasting")
59        .to_primitive();
60
61    (ends, values)
62}
63
64fn runend_encode_primitive<T: NativePType>(elements: &[T]) -> (Buffer<u64>, Buffer<T>) {
65    let mut ends = BufferMut::empty();
66    let mut values = BufferMut::empty();
67
68    if elements.is_empty() {
69        return (ends.freeze(), values.freeze());
70    }
71
72    // Run-end encode the values
73    let mut prev = elements[0];
74    let mut end = 1;
75    for &e in elements.iter().skip(1) {
76        if e != prev {
77            ends.push(end);
78            values.push(prev);
79        }
80        prev = e;
81        end += 1;
82    }
83    ends.push(end);
84    values.push(prev);
85
86    (ends.freeze(), values.freeze())
87}
88
89fn runend_encode_nullable_primitive<T: NativePType>(
90    elements: &[T],
91    element_validity: BitBuffer,
92) -> (Buffer<u64>, PrimitiveArray) {
93    let mut ends = BufferMut::empty();
94    let mut values = BufferMut::empty();
95    let mut validity = BitBufferMut::with_capacity(values.capacity());
96
97    if elements.is_empty() {
98        return (
99            ends.freeze(),
100            PrimitiveArray::new(
101                values,
102                Validity::Array(BoolArray::from(validity.freeze()).into_array()),
103            ),
104        );
105    }
106
107    // Run-end encode the values
108    let mut prev = element_validity.value(0).then(|| elements[0]);
109    let mut end = 1;
110    for e in elements
111        .iter()
112        .zip(element_validity.iter())
113        .map(|(&e, is_valid)| is_valid.then_some(e))
114        .skip(1)
115    {
116        if e != prev {
117            ends.push(end);
118            match prev {
119                None => {
120                    validity.append(false);
121                    values.push(T::default());
122                }
123                Some(p) => {
124                    validity.append(true);
125                    values.push(p);
126                }
127            }
128        }
129        prev = e;
130        end += 1;
131    }
132    ends.push(end);
133
134    match prev {
135        None => {
136            validity.append(false);
137            values.push(T::default());
138        }
139        Some(p) => {
140            validity.append(true);
141            values.push(p);
142        }
143    }
144
145    (
146        ends.freeze(),
147        PrimitiveArray::new(values, Validity::from(validity.freeze())),
148    )
149}
150
151pub fn runend_decode_primitive(
152    ends: PrimitiveArray,
153    values: PrimitiveArray,
154    offset: usize,
155    length: usize,
156) -> PrimitiveArray {
157    match_each_native_ptype!(values.ptype(), |P| {
158        match_each_unsigned_integer_ptype!(ends.ptype(), |E| {
159            runend_decode_typed_primitive(
160                trimmed_ends_iter(ends.as_slice::<E>(), offset, length),
161                values.as_slice::<P>(),
162                values.validity_mask(),
163                values.dtype().nullability(),
164                length,
165            )
166        })
167    })
168}
169
170pub fn runend_decode_bools(
171    ends: PrimitiveArray,
172    values: BoolArray,
173    offset: usize,
174    length: usize,
175) -> BoolArray {
176    match_each_unsigned_integer_ptype!(ends.ptype(), |E| {
177        runend_decode_typed_bool(
178            trimmed_ends_iter(ends.as_slice::<E>(), offset, length),
179            values.bit_buffer(),
180            values.validity_mask(),
181            values.dtype().nullability(),
182            length,
183        )
184    })
185}
186
187pub fn runend_decode_typed_primitive<T: NativePType>(
188    run_ends: impl Iterator<Item = usize>,
189    values: &[T],
190    values_validity: Mask,
191    values_nullability: Nullability,
192    length: usize,
193) -> PrimitiveArray {
194    match values_validity {
195        Mask::AllTrue(_) => {
196            let mut decoded: BufferMut<T> = BufferMut::with_capacity(length);
197            for (end, value) in run_ends.zip_eq(values) {
198                assert!(end <= length, "Runend end must be less than overall length");
199                // SAFETY:
200                // We preallocate enough capacity because we know the total length
201                unsafe { decoded.push_n_unchecked(*value, end - decoded.len()) };
202            }
203            PrimitiveArray::new(decoded, values_nullability.into())
204        }
205        Mask::AllFalse(_) => PrimitiveArray::new(Buffer::<T>::zeroed(length), Validity::AllInvalid),
206        Mask::Values(mask) => {
207            let mut decoded = BufferMut::with_capacity(length);
208            let mut decoded_validity = BitBufferMut::with_capacity(length);
209            for (end, value) in run_ends.zip_eq(
210                values
211                    .iter()
212                    .zip(mask.bit_buffer().iter())
213                    .map(|(&v, is_valid)| is_valid.then_some(v)),
214            ) {
215                assert!(end <= length, "Runend end must be less than overall length");
216                match value {
217                    None => {
218                        decoded_validity.append_n(false, end - decoded.len());
219                        // SAFETY:
220                        // We preallocate enough capacity because we know the total length
221                        unsafe { decoded.push_n_unchecked(T::default(), end - decoded.len()) };
222                    }
223                    Some(value) => {
224                        decoded_validity.append_n(true, end - decoded.len());
225                        // SAFETY:
226                        // We preallocate enough capacity because we know the total length
227                        unsafe { decoded.push_n_unchecked(value, end - decoded.len()) };
228                    }
229                }
230            }
231            PrimitiveArray::new(decoded, Validity::from(decoded_validity.freeze()))
232        }
233    }
234}
235
236pub fn runend_decode_typed_bool(
237    run_ends: impl Iterator<Item = usize>,
238    values: &BitBuffer,
239    values_validity: Mask,
240    values_nullability: Nullability,
241    length: usize,
242) -> BoolArray {
243    match values_validity {
244        Mask::AllTrue(_) => {
245            let mut decoded = BitBufferMut::with_capacity(length);
246            for (end, value) in run_ends.zip_eq(values.iter()) {
247                decoded.append_n(value, end - decoded.len());
248            }
249            BoolArray::from_bit_buffer(decoded.freeze(), values_nullability.into())
250        }
251        Mask::AllFalse(_) => {
252            BoolArray::from_bit_buffer(BitBuffer::new_unset(length), Validity::AllInvalid)
253        }
254        Mask::Values(mask) => {
255            let mut decoded = BitBufferMut::with_capacity(length);
256            let mut decoded_validity = BitBufferMut::with_capacity(length);
257            for (end, value) in run_ends.zip_eq(
258                values
259                    .iter()
260                    .zip(mask.bit_buffer().iter())
261                    .map(|(v, is_valid)| is_valid.then_some(v)),
262            ) {
263                match value {
264                    None => {
265                        decoded_validity.append_n(false, end - decoded.len());
266                        decoded.append_n(false, end - decoded.len());
267                    }
268                    Some(value) => {
269                        decoded_validity.append_n(true, end - decoded.len());
270                        decoded.append_n(value, end - decoded.len());
271                    }
272                }
273            }
274            BoolArray::from_bit_buffer(decoded.freeze(), Validity::from(decoded_validity.freeze()))
275        }
276    }
277}
278
#[cfg(test)]
mod test {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_array::{ToCanonical, assert_arrays_eq};
    use vortex_buffer::{BitBuffer, buffer};

    use crate::compress::{runend_decode_primitive, runend_encode};

    #[test]
    fn encode() {
        let arr = PrimitiveArray::from_iter([1i32, 1, 2, 2, 2, 3, 3, 3, 3, 3]);
        let (ends, values) = runend_encode(&arr);
        let values = values.to_primitive();

        let expected_ends = PrimitiveArray::from_iter(vec![2u8, 5, 10]);
        assert_arrays_eq!(ends, expected_ends);
        let expected_values = PrimitiveArray::from_iter(vec![1i32, 2, 3]);
        assert_arrays_eq!(values, expected_values);
    }

    #[test]
    fn encode_nullable() {
        let arr = PrimitiveArray::new(
            buffer![1i32, 1, 2, 2, 2, 3, 3, 3, 3, 3],
            Validity::from(BitBuffer::from(vec![
                true, true, false, false, true, true, true, true, false, false,
            ])),
        );
        let (ends, values) = runend_encode(&arr);
        let values = values.to_primitive();

        let expected_ends = PrimitiveArray::from_iter(vec![2u8, 4, 5, 8, 10]);
        assert_arrays_eq!(ends, expected_ends);
        let expected_values =
            PrimitiveArray::from_option_iter(vec![Some(1i32), None, Some(2), Some(3), None]);
        assert_arrays_eq!(values, expected_values);
    }

    #[test]
    fn encode_all_null() {
        let arr = PrimitiveArray::new(
            buffer![0, 0, 0, 0, 0],
            Validity::from(BitBuffer::new_unset(5)),
        );
        let (ends, values) = runend_encode(&arr);
        let values = values.to_primitive();

        let expected_ends = PrimitiveArray::from_iter(vec![5u64]);
        assert_arrays_eq!(ends, expected_ends);
        let expected_values = PrimitiveArray::from_option_iter(vec![Option::<i32>::None]);
        assert_arrays_eq!(values, expected_values);
    }

    #[test]
    fn decode() {
        let ends = PrimitiveArray::from_iter([2u32, 5, 10]);
        let values = PrimitiveArray::from_iter([1i32, 2, 3]);
        let decoded = runend_decode_primitive(ends, values, 0, 10);

        let expected = PrimitiveArray::from_iter(vec![1i32, 1, 2, 2, 2, 3, 3, 3, 3, 3]);
        assert_arrays_eq!(decoded, expected);
    }

    /// Exercises the nullable (`Mask::Values`) decode path via an encode/decode round trip.
    #[test]
    fn decode_nullable_roundtrip() {
        let arr = PrimitiveArray::new(
            buffer![1i32, 1, 2, 2, 2, 3, 3, 3, 3, 3],
            Validity::from(BitBuffer::from(vec![
                true, true, false, false, true, true, true, true, false, false,
            ])),
        );
        let (ends, values) = runend_encode(&arr);
        let decoded = runend_decode_primitive(ends, values.to_primitive(), 0, 10);

        let expected = PrimitiveArray::from_option_iter(vec![
            Some(1i32),
            Some(1),
            None,
            None,
            Some(2),
            Some(3),
            Some(3),
            Some(3),
            None,
            None,
        ]);
        assert_arrays_eq!(decoded, expected);
    }
}
342}