Skip to main content

vortex_runend/
compress.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use vortex_array::ArrayRef;
6use vortex_array::IntoArray;
7use vortex_array::ToCanonical;
8use vortex_array::arrays::BoolArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::arrays::PrimitiveArray;
11use vortex_array::expr::stats::Precision;
12use vortex_array::expr::stats::Stat;
13use vortex_array::scalar::Scalar;
14use vortex_array::validity::Validity;
15use vortex_array::vtable::ValidityHelper;
16use vortex_buffer::BitBuffer;
17use vortex_buffer::BitBufferMut;
18use vortex_buffer::Buffer;
19use vortex_buffer::BufferMut;
20use vortex_buffer::buffer;
21use vortex_dtype::NativePType;
22use vortex_dtype::Nullability;
23use vortex_dtype::match_each_native_ptype;
24use vortex_dtype::match_each_unsigned_integer_ptype;
25use vortex_error::VortexExpect;
26use vortex_error::VortexResult;
27use vortex_mask::Mask;
28
29use crate::iter::trimmed_ends_iter;
30
31/// Run-end encode a `PrimitiveArray`, returning a tuple of `(ends, values)`.
32pub fn runend_encode(array: &PrimitiveArray) -> (PrimitiveArray, ArrayRef) {
33    let validity = match array.validity() {
34        Validity::NonNullable => None,
35        Validity::AllValid => None,
36        Validity::AllInvalid => {
37            // We can trivially return an all-null REE array
38            let ends = PrimitiveArray::new(buffer![array.len() as u64], Validity::NonNullable);
39            ends.statistics()
40                .set(Stat::IsStrictSorted, Precision::Exact(true.into()));
41            return (
42                ends,
43                ConstantArray::new(Scalar::null(array.dtype().clone()), 1).into_array(),
44            );
45        }
46        Validity::Array(a) => Some(a.to_bool().to_bit_buffer()),
47    };
48
49    let (ends, values) = match validity {
50        None => {
51            match_each_native_ptype!(array.ptype(), |P| {
52                let (ends, values) = runend_encode_primitive(array.as_slice::<P>());
53                (
54                    PrimitiveArray::new(ends, Validity::NonNullable),
55                    PrimitiveArray::new(values, array.dtype().nullability().into()).into_array(),
56                )
57            })
58        }
59        Some(validity) => {
60            match_each_native_ptype!(array.ptype(), |P| {
61                let (ends, values) =
62                    runend_encode_nullable_primitive(array.as_slice::<P>(), validity);
63                (
64                    PrimitiveArray::new(ends, Validity::NonNullable),
65                    values.into_array(),
66                )
67            })
68        }
69    };
70
71    let ends = ends
72        .narrow()
73        .vortex_expect("Ends must succeed downcasting")
74        .to_primitive();
75
76    ends.statistics()
77        .set(Stat::IsStrictSorted, Precision::Exact(true.into()));
78
79    (ends, values)
80}
81
82fn runend_encode_primitive<T: NativePType>(elements: &[T]) -> (Buffer<u64>, Buffer<T>) {
83    let mut ends = BufferMut::empty();
84    let mut values = BufferMut::empty();
85
86    if elements.is_empty() {
87        return (ends.freeze(), values.freeze());
88    }
89
90    // Run-end encode the values
91    let mut prev = elements[0];
92    let mut end = 1;
93    for &e in elements.iter().skip(1) {
94        if e != prev {
95            ends.push(end);
96            values.push(prev);
97        }
98        prev = e;
99        end += 1;
100    }
101    ends.push(end);
102    values.push(prev);
103
104    (ends.freeze(), values.freeze())
105}
106
107fn runend_encode_nullable_primitive<T: NativePType>(
108    elements: &[T],
109    element_validity: BitBuffer,
110) -> (Buffer<u64>, PrimitiveArray) {
111    let mut ends = BufferMut::empty();
112    let mut values = BufferMut::empty();
113    let mut validity = BitBufferMut::with_capacity(values.capacity());
114
115    if elements.is_empty() {
116        return (
117            ends.freeze(),
118            PrimitiveArray::new(
119                values,
120                Validity::Array(BoolArray::from(validity.freeze()).into_array()),
121            ),
122        );
123    }
124
125    // Run-end encode the values
126    let mut prev = element_validity.value(0).then(|| elements[0]);
127    let mut end = 1;
128    for e in elements
129        .iter()
130        .zip(element_validity.iter())
131        .map(|(&e, is_valid)| is_valid.then_some(e))
132        .skip(1)
133    {
134        if e != prev {
135            ends.push(end);
136            match prev {
137                None => {
138                    validity.append(false);
139                    values.push(T::default());
140                }
141                Some(p) => {
142                    validity.append(true);
143                    values.push(p);
144                }
145            }
146        }
147        prev = e;
148        end += 1;
149    }
150    ends.push(end);
151
152    match prev {
153        None => {
154            validity.append(false);
155            values.push(T::default());
156        }
157        Some(p) => {
158            validity.append(true);
159            values.push(p);
160        }
161    }
162
163    (
164        ends.freeze(),
165        PrimitiveArray::new(values, Validity::from(validity.freeze())),
166    )
167}
168
169pub fn runend_decode_primitive(
170    ends: PrimitiveArray,
171    values: PrimitiveArray,
172    offset: usize,
173    length: usize,
174) -> VortexResult<PrimitiveArray> {
175    let validity_mask = values.validity_mask()?;
176    Ok(match_each_native_ptype!(values.ptype(), |P| {
177        match_each_unsigned_integer_ptype!(ends.ptype(), |E| {
178            runend_decode_typed_primitive(
179                trimmed_ends_iter(ends.as_slice::<E>(), offset, length),
180                values.as_slice::<P>(),
181                validity_mask,
182                values.dtype().nullability(),
183                length,
184            )
185        })
186    }))
187}
188
189pub fn runend_decode_bools(
190    ends: PrimitiveArray,
191    values: BoolArray,
192    offset: usize,
193    length: usize,
194) -> VortexResult<BoolArray> {
195    let validity_mask = values.validity_mask()?;
196    Ok(match_each_unsigned_integer_ptype!(ends.ptype(), |E| {
197        runend_decode_typed_bool(
198            trimmed_ends_iter(ends.as_slice::<E>(), offset, length),
199            &values.to_bit_buffer(),
200            validity_mask,
201            values.dtype().nullability(),
202            length,
203        )
204    }))
205}
206
207pub fn runend_decode_typed_primitive<T: NativePType>(
208    run_ends: impl Iterator<Item = usize>,
209    values: &[T],
210    values_validity: Mask,
211    values_nullability: Nullability,
212    length: usize,
213) -> PrimitiveArray {
214    match values_validity {
215        Mask::AllTrue(_) => {
216            let mut decoded: BufferMut<T> = BufferMut::with_capacity(length);
217            for (end, value) in run_ends.zip_eq(values) {
218                assert!(
219                    end >= decoded.len(),
220                    "Runend ends must be monotonic, got {end} after {}",
221                    decoded.len()
222                );
223                assert!(end <= length, "Runend end must be less than overall length");
224                // SAFETY:
225                // We preallocate enough capacity because we know the total length
226                unsafe { decoded.push_n_unchecked(*value, end - decoded.len()) };
227            }
228            PrimitiveArray::new(decoded, values_nullability.into())
229        }
230        Mask::AllFalse(_) => PrimitiveArray::new(Buffer::<T>::zeroed(length), Validity::AllInvalid),
231        Mask::Values(mask) => {
232            let mut decoded = BufferMut::with_capacity(length);
233            let mut decoded_validity = BitBufferMut::with_capacity(length);
234            for (end, value) in run_ends.zip_eq(
235                values
236                    .iter()
237                    .zip(mask.bit_buffer().iter())
238                    .map(|(&v, is_valid)| is_valid.then_some(v)),
239            ) {
240                assert!(
241                    end >= decoded.len(),
242                    "Runend ends must be monotonic, got {end} after {}",
243                    decoded.len()
244                );
245                assert!(end <= length, "Runend end must be less than overall length");
246                match value {
247                    None => {
248                        decoded_validity.append_n(false, end - decoded.len());
249                        // SAFETY:
250                        // We preallocate enough capacity because we know the total length
251                        unsafe { decoded.push_n_unchecked(T::default(), end - decoded.len()) };
252                    }
253                    Some(value) => {
254                        decoded_validity.append_n(true, end - decoded.len());
255                        // SAFETY:
256                        // We preallocate enough capacity because we know the total length
257                        unsafe { decoded.push_n_unchecked(value, end - decoded.len()) };
258                    }
259                }
260            }
261            PrimitiveArray::new(decoded, Validity::from(decoded_validity.freeze()))
262        }
263    }
264}
265
266pub fn runend_decode_typed_bool(
267    run_ends: impl Iterator<Item = usize>,
268    values: &BitBuffer,
269    values_validity: Mask,
270    values_nullability: Nullability,
271    length: usize,
272) -> BoolArray {
273    match values_validity {
274        Mask::AllTrue(_) => {
275            let mut decoded = BitBufferMut::with_capacity(length);
276            for (end, value) in run_ends.zip_eq(values.iter()) {
277                decoded.append_n(value, end - decoded.len());
278            }
279            BoolArray::new(decoded.freeze(), values_nullability.into())
280        }
281        Mask::AllFalse(_) => BoolArray::new(BitBuffer::new_unset(length), Validity::AllInvalid),
282        Mask::Values(mask) => {
283            let mut decoded = BitBufferMut::with_capacity(length);
284            let mut decoded_validity = BitBufferMut::with_capacity(length);
285            for (end, value) in run_ends.zip_eq(
286                values
287                    .iter()
288                    .zip(mask.bit_buffer().iter())
289                    .map(|(v, is_valid)| is_valid.then_some(v)),
290            ) {
291                match value {
292                    None => {
293                        decoded_validity.append_n(false, end - decoded.len());
294                        decoded.append_n(false, end - decoded.len());
295                    }
296                    Some(value) => {
297                        decoded_validity.append_n(true, end - decoded.len());
298                        decoded.append_n(value, end - decoded.len());
299                    }
300                }
301            }
302            BoolArray::new(decoded.freeze(), Validity::from(decoded_validity.freeze()))
303        }
304    }
305}
306
#[cfg(test)]
mod test {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::validity::Validity;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::compress::runend_decode_primitive;
    use crate::compress::runend_encode;

    /// Non-nullable input with three runs; the ends come back as `u8`.
    #[test]
    fn encode() {
        let input = PrimitiveArray::from_iter([1i32, 1, 2, 2, 2, 3, 3, 3, 3, 3]);
        let (run_ends, run_values) = runend_encode(&input);
        let run_values = run_values.to_primitive();

        let want_ends = PrimitiveArray::from_iter(vec![2u8, 5, 10]);
        let want_values = PrimitiveArray::from_iter(vec![1i32, 2, 3]);
        assert_arrays_eq!(run_ends, want_ends);
        assert_arrays_eq!(run_values, want_values);
    }

    /// Null elements split runs: a null never equals a valid value, and adjacent nulls merge
    /// even when the values stored underneath them differ.
    #[test]
    fn encode_nullable() {
        let element_validity = BitBuffer::from(vec![
            true, true, false, false, true, true, true, true, false, false,
        ]);
        let input = PrimitiveArray::new(
            buffer![1i32, 1, 2, 2, 2, 3, 3, 3, 3, 3],
            Validity::from(element_validity),
        );
        let (run_ends, run_values) = runend_encode(&input);
        let run_values = run_values.to_primitive();

        let want_ends = PrimitiveArray::from_iter(vec![2u8, 4, 5, 8, 10]);
        let want_values =
            PrimitiveArray::from_option_iter(vec![Some(1i32), None, Some(2), Some(3), None]);
        assert_arrays_eq!(run_ends, want_ends);
        assert_arrays_eq!(run_values, want_values);
    }

    /// An all-null input becomes a single null run whose end stays `u64`.
    #[test]
    fn encode_all_null() {
        let input = PrimitiveArray::new(
            buffer![0, 0, 0, 0, 0],
            Validity::from(BitBuffer::new_unset(5)),
        );
        let (run_ends, run_values) = runend_encode(&input);
        let run_values = run_values.to_primitive();

        let want_ends = PrimitiveArray::from_iter(vec![5u64]);
        let want_values = PrimitiveArray::from_option_iter(vec![Option::<i32>::None]);
        assert_arrays_eq!(run_ends, want_ends);
        assert_arrays_eq!(run_values, want_values);
    }

    /// Decoding three runs with no offset reproduces the flat array.
    #[test]
    fn decode() -> VortexResult<()> {
        let run_ends = PrimitiveArray::from_iter([2u32, 5, 10]);
        let run_values = PrimitiveArray::from_iter([1i32, 2, 3]);
        let flat = runend_decode_primitive(run_ends, run_values, 0, 10)?;

        let want = PrimitiveArray::from_iter(vec![1i32, 1, 2, 2, 2, 3, 3, 3, 3, 3]);
        assert_arrays_eq!(flat, want);
        Ok(())
    }
}
375}