// vortex_fastlanes/rle/array/rle_compress.rs
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use arrayref::array_mut_ref;
use arrayref::array_ref;
use fastlanes::RLE;
use vortex_array::IntoArray;
use vortex_array::ToCanonical;
use vortex_array::arrays::NativeValue;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::dtype::NativePType;
use vortex_array::match_each_native_ptype;
use vortex_array::validity::Validity;
use vortex_array::vtable::ValidityHelper;
use vortex_buffer::BitBufferMut;
use vortex_buffer::BufferMut;
use vortex_error::VortexResult;

use crate::FL_CHUNK_SIZE;
use crate::RLEArray;
21
impl RLEArray {
    /// Encodes a primitive array using FastLanes RLE.
    ///
    /// Dispatches on the array's ptype via `match_each_native_ptype!`, so all
    /// native primitive types (unsigned/signed integers and floats) are
    /// supported — see the roundtrip tests below.
    pub fn encode(array: &PrimitiveArray) -> VortexResult<Self> {
        match_each_native_ptype!(array.ptype(), |T| { rle_encode_typed::<T>(array) })
    }
}
28
/// Encodes a primitive array using FastLanes RLE.
///
/// The input is processed in fixed chunks of `FL_CHUNK_SIZE` (1024) elements.
/// In case the input array length is % 1024 != 0, the last chunk is padded
/// by repeating the final value (so padding never introduces a new run).
///
/// Output layout:
/// - values: the run values of all chunks, concatenated (`value_count_acc` total);
/// - indices: one `u16` per padded input position, chunk-relative run index;
/// - offsets: per-chunk prefix sums of run counts, used to rebase the
///   chunk-relative indices into the concatenated values buffer.
fn rle_encode_typed<T>(array: &PrimitiveArray) -> VortexResult<RLEArray>
where
    T: NativePType + RLE,
    NativeValue<T>: RLE,
{
    let values = array.as_slice::<T>();
    let len = values.len();
    // Total encoded length, including the padded tail of a partial last chunk.
    let padded_len = len.next_multiple_of(FL_CHUNK_SIZE);

    // Allocate capacity up to the next multiple of chunk size.
    // `values_buf` may end up shorter (one entry per run), but a chunk of
    // all-distinct values is the worst case, so full capacity is needed.
    let mut values_buf = BufferMut::<NativeValue<T>>::with_capacity(padded_len);
    let mut indices_buf = BufferMut::<u16>::with_capacity(padded_len);

    // Pre-allocate for one offset per chunk (div_ceil counts a partial chunk).
    let mut values_idx_offsets = BufferMut::<u64>::with_capacity(len.div_ceil(FL_CHUNK_SIZE));

    let values_uninit = values_buf.spare_capacity_mut();
    let indices_uninit = indices_buf.spare_capacity_mut();
    let mut value_count_acc = 0; // Chunk value count prefix sum.

    let mut chunks = values.chunks_exact(FL_CHUNK_SIZE);

    // Encodes one full chunk starting at absolute position `chunk_start_idx`
    // directly into the spare (uninitialized) capacity of the output buffers.
    let mut process_chunk = |chunk_start_idx: usize, input: &[T; FL_CHUNK_SIZE]| {
        // SAFETY: NativeValue is repr(transparent)
        let input: &[NativeValue<T>; FL_CHUNK_SIZE] = unsafe { std::mem::transmute(input) };

        // SAFETY: `MaybeUninit<NativeValue<T>>` and `NativeValue<T>` have the same layout.
        // Each chunk's run values land after the previous chunks' runs.
        let rle_vals: &mut [NativeValue<T>] =
            unsafe { std::mem::transmute(&mut values_uninit[value_count_acc..][..FL_CHUNK_SIZE]) };

        // SAFETY: `MaybeUninit<u16>` and `u16` have the same layout.
        // Indices are dense: exactly FL_CHUNK_SIZE entries per chunk.
        let rle_idxs: &mut [u16] =
            unsafe { std::mem::transmute(&mut indices_uninit[chunk_start_idx..][..FL_CHUNK_SIZE]) };

        // Capture chunk start indices. This is necessary as indices
        // returned from `T::encode` are relative to the chunk.
        values_idx_offsets.push(value_count_acc as u64);

        let value_count = NativeValue::<T>::encode(
            input,
            array_mut_ref![rle_vals, 0, FL_CHUNK_SIZE],
            array_mut_ref![rle_idxs, 0, FL_CHUNK_SIZE],
        );

        value_count_acc += value_count;
    };

    for (chunk_idx, chunk_slice) in chunks.by_ref().enumerate() {
        process_chunk(
            chunk_idx * FL_CHUNK_SIZE,
            array_ref![chunk_slice, 0, FL_CHUNK_SIZE],
        );
    }

    let remainder = chunks.remainder();
    if !remainder.is_empty() {
        // Repeat the last value for padding to prevent
        // accounting for an additional value change.
        // (`values[len - 1]` is safe here: a non-empty remainder implies len > 0.)
        let mut padded_chunk = [values[len - 1]; FL_CHUNK_SIZE];
        padded_chunk[..remainder.len()].copy_from_slice(remainder);
        process_chunk((len / FL_CHUNK_SIZE) * FL_CHUNK_SIZE, &padded_chunk);
    }

    // SAFETY: `process_chunk` initialized exactly `value_count_acc` run values
    // and all `padded_len` index slots (FL_CHUNK_SIZE per processed chunk).
    unsafe {
        values_buf.set_len(value_count_acc);
        indices_buf.set_len(padded_len);
    }

    // SAFETY: NativeValue<T> is repr(transparent) to T.
    let values_buf = unsafe { values_buf.transmute::<T>().freeze() };

    RLEArray::try_new(
        values_buf.into_array(),
        PrimitiveArray::new(indices_buf.freeze(), padded_validity(array)).into_array(),
        values_idx_offsets.into_array(),
        0,
        // Logical length is the original (unpadded) array length.
        array.len(),
    )
}
111
112/// Returns validity padded to the next 1024 chunk for a given array.
113fn padded_validity(array: &PrimitiveArray) -> Validity {
114    match array.validity() {
115        Validity::NonNullable => Validity::NonNullable,
116        Validity::AllValid => Validity::AllValid,
117        Validity::AllInvalid => Validity::AllInvalid,
118        Validity::Array(validity_array) => {
119            let len = array.len();
120            let padded_len = len.next_multiple_of(FL_CHUNK_SIZE);
121
122            if len == padded_len {
123                return Validity::Array(validity_array.clone());
124            }
125
126            let mut builder = BitBufferMut::with_capacity(padded_len);
127
128            let bool_array = validity_array.to_bool();
129            builder.append_buffer(&bool_array.to_bit_buffer());
130            builder.append_n(false, padded_len - len);
131
132            Validity::from(builder.freeze())
133        }
134    }
135}
136
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::half::f16;
    use vortex_buffer::Buffer;

    use super::*;

    /// Encodes a buffer with RLE and decodes it back to a primitive array.
    fn encode_decode<T: NativePType>(buf: Buffer<T>) -> PrimitiveArray {
        RLEArray::encode(&buf.into_array().to_primitive())
            .unwrap()
            .to_primitive()
    }

    #[test]
    fn test_encode_decode() {
        // u8
        let decoded = encode_decode([1u8, 1, 2, 2, 3, 3].iter().copied().collect());
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter(vec![1u8, 1, 2, 2, 3, 3]));

        // u16
        let decoded = encode_decode([100u16, 100, 200, 200].iter().copied().collect());
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter(vec![100u16, 100, 200, 200]));

        // u64
        let decoded = encode_decode([1000u64, 1000, 2000].iter().copied().collect());
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter(vec![1000u64, 1000, 2000]));
    }

    #[test]
    fn test_length() {
        let buf: Buffer<u32> = [1, 1, 2, 2, 2, 3].iter().copied().collect();
        let encoded = RLEArray::encode(&buf.into_array().to_primitive()).unwrap();
        assert_eq!(encoded.len(), 6);
    }

    #[test]
    fn test_empty_length() {
        let buf: Buffer<u32> = Buffer::empty();
        let encoded = RLEArray::encode(&buf.into_array().to_primitive()).unwrap();

        assert_eq!(encoded.len(), 0);
        assert_eq!(encoded.values().len(), 0);
    }

    #[test]
    fn test_single_value() {
        let buf: Buffer<u16> = vec![42; 2000].into_iter().collect();
        let encoded = RLEArray::encode(&buf.into_array().to_primitive()).unwrap();
        assert_eq!(encoded.values().len(), 2); // 2 chunks, each storing value 42

        // Verify round-trip.
        let expected = PrimitiveArray::from_iter(vec![42u16; 2000]);
        assert_arrays_eq!(encoded.to_primitive(), expected);
    }

    #[test]
    fn test_all_different() {
        let buf: Buffer<u8> = (0u8..=255).collect();
        let encoded = RLEArray::encode(&buf.into_array().to_primitive()).unwrap();
        assert_eq!(encoded.values().len(), 256);

        // Verify round-trip.
        let expected = PrimitiveArray::from_iter((0u8..=255).collect::<Vec<_>>());
        assert_arrays_eq!(encoded.to_primitive(), expected);
    }

    #[test]
    fn test_partial_last_chunk() {
        // Test array with partial last chunk (not divisible by 1024)
        let buf: Buffer<u32> = (0..1500).map(|i| (i / 100) as u32).collect();
        let array = buf.into_array();
        let encoded = RLEArray::encode(&array.to_primitive()).unwrap();

        assert_eq!(encoded.len(), 1500);
        assert_arrays_eq!(encoded, array);
        // 2 chunks: 1024 + 476 elements
        assert_eq!(encoded.values_idx_offsets().len(), 2);
    }

    #[test]
    fn test_two_full_chunks() {
        // Array that spans exactly 2 chunks (2048 elements)
        let buf: Buffer<u32> = (0..2048).map(|i| (i / 100) as u32).collect();
        let array = buf.into_array();
        let encoded = RLEArray::encode(&array.to_primitive()).unwrap();

        assert_eq!(encoded.len(), 2048);
        assert_arrays_eq!(encoded, array);
        assert_eq!(encoded.values_idx_offsets().len(), 2);
    }

    #[rstest]
    #[case::u8((0u8..100).collect::<Buffer<u8>>())]
    #[case::u16((0u16..2000).collect::<Buffer<u16>>())]
    #[case::u32((0u32..2000).collect::<Buffer<u32>>())]
    #[case::u64((0u64..2000).collect::<Buffer<u64>>())]
    #[case::i8((-100i8..100).collect::<Buffer<i8>>())]
    #[case::i16((-2000i16..2000).collect::<Buffer<i16>>())]
    #[case::i32((-2000i32..2000).collect::<Buffer<i32>>())]
    #[case::i64((-2000i64..2000).collect::<Buffer<i64>>())]
    #[case::f16((-2000..2000).map(|i| f16::from_f32(i as f32)).collect::<Buffer<f16>>())]
    #[case::f32((-2000..2000).map(|i| i as f32).collect::<Buffer<f32>>())]
    #[case::f64((-2000..2000).map(|i| i as f64).collect::<Buffer<f64>>())]
    fn test_roundtrip_primitive_types<T: NativePType>(#[case] values: Buffer<T>) {
        let primitive = values.clone().into_array().to_primitive();
        let decoded = RLEArray::encode(&primitive).unwrap().to_primitive();
        let expected = PrimitiveArray::new(values, primitive.validity().clone());
        assert_arrays_eq!(decoded, expected);
    }

    // Regression test: RLE compression properly supports decoding pos/neg zeros
    // See <https://github.com/vortex-data/vortex/issues/6491>
    #[rstest]
    #[case(vec![f16::ZERO, f16::NEG_ZERO])]
    #[case(vec![0f32, -0f32])]
    #[case(vec![0f64, -0f64])]
    fn test_float_zeros<T: NativePType + RLE>(#[case] values: Vec<T>) {
        let primitive = PrimitiveArray::from_iter(values);
        let decoded = RLEArray::encode(&primitive).unwrap().to_primitive();
        assert_arrays_eq!(primitive, decoded);
    }
}