vortex_runend/
compress.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_buffer::BooleanBufferBuilder;
5use itertools::Itertools;
6use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray, PrimitiveArray};
7use vortex_array::compress::downscale_integer_array;
8use vortex_array::validity::Validity;
9use vortex_array::vtable::ValidityHelper;
10use vortex_array::{ArrayRef, IntoArray, ToCanonical};
11use vortex_buffer::{Buffer, BufferMut, buffer};
12use vortex_dtype::{
13    NativePType, Nullability, match_each_native_ptype, match_each_unsigned_integer_ptype,
14};
15use vortex_error::VortexResult;
16use vortex_mask::Mask;
17use vortex_scalar::Scalar;
18
19use crate::iter::trimmed_ends_iter;
20
21/// Run-end encode a `PrimitiveArray`, returning a tuple of `(ends, values)`.
22pub fn runend_encode(array: &PrimitiveArray) -> VortexResult<(PrimitiveArray, ArrayRef)> {
23    let validity = match array.validity() {
24        Validity::NonNullable => None,
25        Validity::AllValid => None,
26        Validity::AllInvalid => {
27            // We can trivially return an all-null REE array
28            return Ok((
29                PrimitiveArray::new(buffer![array.len() as u64], Validity::NonNullable),
30                ConstantArray::new(Scalar::null(array.dtype().clone()), 1).into_array(),
31            ));
32        }
33        Validity::Array(a) => Some(a.to_bool()?.boolean_buffer().clone()),
34    };
35
36    let (ends, values) = match validity {
37        None => {
38            match_each_native_ptype!(array.ptype(), |P| {
39                let (ends, values) = runend_encode_primitive(array.as_slice::<P>());
40                (
41                    PrimitiveArray::new(ends, Validity::NonNullable),
42                    PrimitiveArray::new(values, array.dtype().nullability().into()).into_array(),
43                )
44            })
45        }
46        Some(validity) => {
47            match_each_native_ptype!(array.ptype(), |P| {
48                let (ends, values) =
49                    runend_encode_nullable_primitive(array.as_slice::<P>(), validity);
50                (
51                    PrimitiveArray::new(ends, Validity::NonNullable),
52                    values.into_array(),
53                )
54            })
55        }
56    };
57
58    let ends = downscale_integer_array(ends.to_array())?.to_primitive()?;
59
60    Ok((ends, values))
61}
62
63fn runend_encode_primitive<T: NativePType>(elements: &[T]) -> (Buffer<u64>, Buffer<T>) {
64    let mut ends = BufferMut::empty();
65    let mut values = BufferMut::empty();
66
67    if elements.is_empty() {
68        return (ends.freeze(), values.freeze());
69    }
70
71    // Run-end encode the values
72    let mut prev = elements[0];
73    let mut end = 1;
74    for &e in elements.iter().skip(1) {
75        if e != prev {
76            ends.push(end);
77            values.push(prev);
78        }
79        prev = e;
80        end += 1;
81    }
82    ends.push(end);
83    values.push(prev);
84
85    (ends.freeze(), values.freeze())
86}
87
88fn runend_encode_nullable_primitive<T: NativePType>(
89    elements: &[T],
90    element_validity: BooleanBuffer,
91) -> (Buffer<u64>, PrimitiveArray) {
92    let mut ends = BufferMut::empty();
93    let mut values = BufferMut::empty();
94    let mut validity = BooleanBufferBuilder::new(values.capacity());
95
96    if elements.is_empty() {
97        return (
98            ends.freeze(),
99            PrimitiveArray::new(
100                values,
101                Validity::Array(BoolArray::from(validity.finish()).into_array()),
102            ),
103        );
104    }
105
106    // Run-end encode the values
107    let mut prev = element_validity.value(0).then(|| elements[0]);
108    let mut end = 1;
109    for e in elements
110        .iter()
111        .zip(element_validity.iter())
112        .map(|(&e, is_valid)| is_valid.then_some(e))
113        .skip(1)
114    {
115        if e != prev {
116            ends.push(end);
117            match prev {
118                None => {
119                    validity.append(false);
120                    values.push(T::default());
121                }
122                Some(p) => {
123                    validity.append(true);
124                    values.push(p);
125                }
126            }
127        }
128        prev = e;
129        end += 1;
130    }
131    ends.push(end);
132
133    match prev {
134        None => {
135            validity.append(false);
136            values.push(T::default());
137        }
138        Some(p) => {
139            validity.append(true);
140            values.push(p);
141        }
142    }
143
144    (
145        ends.freeze(),
146        PrimitiveArray::new(values, Validity::from(validity.finish())),
147    )
148}
149
150pub fn runend_decode_primitive(
151    ends: PrimitiveArray,
152    values: PrimitiveArray,
153    offset: usize,
154    length: usize,
155) -> VortexResult<PrimitiveArray> {
156    match_each_native_ptype!(values.ptype(), |P| {
157        match_each_unsigned_integer_ptype!(ends.ptype(), |E| {
158            runend_decode_typed_primitive(
159                trimmed_ends_iter(ends.as_slice::<E>(), offset, length),
160                values.as_slice::<P>(),
161                values.validity_mask()?,
162                values.dtype().nullability(),
163                length,
164            )
165        })
166    })
167}
168
169pub fn runend_decode_bools(
170    ends: PrimitiveArray,
171    values: BoolArray,
172    offset: usize,
173    length: usize,
174) -> VortexResult<BoolArray> {
175    match_each_unsigned_integer_ptype!(ends.ptype(), |E| {
176        runend_decode_typed_bool(
177            trimmed_ends_iter(ends.as_slice::<E>(), offset, length),
178            values.boolean_buffer().clone(),
179            values.validity_mask()?,
180            values.dtype().nullability(),
181            length,
182        )
183    })
184}
185
186pub fn runend_decode_typed_primitive<T: NativePType>(
187    run_ends: impl Iterator<Item = usize>,
188    values: &[T],
189    values_validity: Mask,
190    values_nullability: Nullability,
191    length: usize,
192) -> VortexResult<PrimitiveArray> {
193    Ok(match values_validity {
194        Mask::AllTrue(_) => {
195            let mut decoded: BufferMut<T> = BufferMut::with_capacity(length);
196            for (end, value) in run_ends.zip_eq(values) {
197                assert!(end <= length, "Runend end must be less than overall length");
198                // SAFETY:
199                // We preallocate enough capacity because we know the total length
200                unsafe { decoded.push_n_unchecked(*value, end - decoded.len()) };
201            }
202            PrimitiveArray::new(decoded, values_nullability.into())
203        }
204        Mask::AllFalse(_) => PrimitiveArray::new(Buffer::<T>::zeroed(length), Validity::AllInvalid),
205        Mask::Values(mask) => {
206            let mut decoded = BufferMut::with_capacity(length);
207            let mut decoded_validity = BooleanBufferBuilder::new(length);
208            for (end, value) in run_ends.zip_eq(
209                values
210                    .iter()
211                    .zip(mask.boolean_buffer().iter())
212                    .map(|(&v, is_valid)| is_valid.then_some(v)),
213            ) {
214                assert!(end <= length, "Runend end must be less than overall length");
215                match value {
216                    None => {
217                        decoded_validity.append_n(end - decoded.len(), false);
218                        // SAFETY:
219                        // We preallocate enough capacity because we know the total length
220                        unsafe { decoded.push_n_unchecked(T::default(), end - decoded.len()) };
221                    }
222                    Some(value) => {
223                        decoded_validity.append_n(end - decoded.len(), true);
224                        // SAFETY:
225                        // We preallocate enough capacity because we know the total length
226                        unsafe { decoded.push_n_unchecked(value, end - decoded.len()) };
227                    }
228                }
229            }
230            PrimitiveArray::new(decoded, Validity::from(decoded_validity.finish()))
231        }
232    })
233}
234
235pub fn runend_decode_typed_bool(
236    run_ends: impl Iterator<Item = usize>,
237    values: BooleanBuffer,
238    values_validity: Mask,
239    values_nullability: Nullability,
240    length: usize,
241) -> VortexResult<BoolArray> {
242    Ok(match values_validity {
243        Mask::AllTrue(_) => {
244            let mut decoded = BooleanBufferBuilder::new(length);
245            for (end, value) in run_ends.zip_eq(values.iter()) {
246                decoded.append_n(end - decoded.len(), value);
247            }
248            BoolArray::new(decoded.finish(), values_nullability.into())
249        }
250        Mask::AllFalse(_) => BoolArray::new(BooleanBuffer::new_unset(length), Validity::AllInvalid),
251        Mask::Values(mask) => {
252            let mut decoded = BooleanBufferBuilder::new(length);
253            let mut decoded_validity = BooleanBufferBuilder::new(length);
254            for (end, value) in run_ends.zip_eq(
255                values
256                    .iter()
257                    .zip(mask.boolean_buffer().iter())
258                    .map(|(v, is_valid)| is_valid.then_some(v)),
259            ) {
260                match value {
261                    None => {
262                        decoded_validity.append_n(end - decoded.len(), false);
263                        decoded.append_n(end - decoded.len(), false);
264                    }
265                    Some(value) => {
266                        decoded_validity.append_n(end - decoded.len(), true);
267                        decoded.append_n(end - decoded.len(), value);
268                    }
269                }
270            }
271            BoolArray::new(decoded.finish(), Validity::from(decoded_validity.finish()))
272        }
273    })
274}
275
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::compress::{runend_decode_primitive, runend_encode};

    /// Each run of a non-nullable array becomes one `(end, value)` pair; small ends are
    /// downscaled to `u8`.
    #[test]
    fn encode() {
        let input = PrimitiveArray::from_iter([1i32, 1, 2, 2, 2, 3, 3, 3, 3, 3]);

        let (ends, encoded) = runend_encode(&input).unwrap();
        let encoded = encoded.to_primitive().unwrap();

        assert_eq!(ends.as_slice::<u8>(), &[2, 5, 10]);
        assert_eq!(encoded.as_slice::<i32>(), &[1, 2, 3]);
    }

    /// Null slots form their own runs and are stored as `0` underneath an unset validity bit.
    #[test]
    fn encode_nullable() {
        let mask = BooleanBuffer::from(vec![
            true, true, false, false, true, true, true, true, false, false,
        ]);
        let input = PrimitiveArray::new(
            buffer![1i32, 1, 2, 2, 2, 3, 3, 3, 3, 3],
            Validity::from(mask),
        );

        let (ends, encoded) = runend_encode(&input).unwrap();
        let encoded = encoded.to_primitive().unwrap();

        assert_eq!(ends.as_slice::<u8>(), &[2, 4, 5, 8, 10]);
        assert_eq!(encoded.as_slice::<i32>(), &[1, 0, 2, 3, 0]);
    }

    /// An all-null array collapses to a single run spanning the whole array, with `u64` ends.
    #[test]
    fn encode_all_null() {
        let input = PrimitiveArray::new(
            buffer![0i32, 0, 0, 0, 0],
            Validity::from(BooleanBuffer::new_unset(5)),
        );

        let (ends, encoded) = runend_encode(&input).unwrap();
        let encoded = encoded.to_primitive().unwrap();

        assert_eq!(ends.as_slice::<u64>(), &[5]);
        assert_eq!(encoded.as_slice::<i32>(), &[0]);
    }

    /// Decoding expands every value up to its run end.
    #[test]
    fn decode() {
        let run_ends = PrimitiveArray::from_iter([2u32, 5, 10]);
        let run_values = PrimitiveArray::from_iter([1i32, 2, 3]);

        let decoded = runend_decode_primitive(run_ends, run_values, 0, 10).unwrap();

        assert_eq!(
            decoded.as_slice::<i32>(),
            &[1i32, 1, 2, 2, 2, 3, 3, 3, 3, 3]
        );
    }
}