vortex_fastlanes/rle/array/
rle_compress.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrayref::array_mut_ref;
5use arrayref::array_ref;
6use fastlanes::RLE;
7use vortex_array::IntoArray;
8use vortex_array::ToCanonical;
9use vortex_array::arrays::PrimitiveArray;
10use vortex_array::validity::Validity;
11use vortex_array::vtable::ValidityHelper;
12use vortex_buffer::BitBufferMut;
13use vortex_buffer::BufferMut;
14use vortex_dtype::NativePType;
15use vortex_dtype::match_each_native_ptype;
16use vortex_error::VortexResult;
17
18use crate::FL_CHUNK_SIZE;
19use crate::RLEArray;
20
21impl RLEArray {
22    /// Encodes a primitive array of unsigned integers using FastLanes RLE.
23    pub fn encode(array: &PrimitiveArray) -> VortexResult<Self> {
24        match_each_native_ptype!(array.ptype(), |T| { rle_encode_typed::<T>(array) })
25    }
26}
27
28/// Encodes a primitive array of unsigned integers using FastLanes RLE.
29///
30/// In case the input array length is % 1024 != 0, the last chunk is padded.
31fn rle_encode_typed<T>(array: &PrimitiveArray) -> VortexResult<RLEArray>
32where
33    T: NativePType + RLE,
34{
35    let values = array.as_slice::<T>();
36    let len = values.len();
37    let padded_len = len.next_multiple_of(FL_CHUNK_SIZE);
38
39    // Allocate capacity up to the next multiple of chunk size.
40    let mut values_buf = BufferMut::<T>::with_capacity(padded_len);
41    let mut indices_buf = BufferMut::<u16>::with_capacity(padded_len);
42
43    // Pre-allocate for one offset per chunk.
44    let mut values_idx_offsets = BufferMut::<u64>::with_capacity(len.div_ceil(FL_CHUNK_SIZE));
45
46    let values_uninit = values_buf.spare_capacity_mut();
47    let indices_uninit = indices_buf.spare_capacity_mut();
48    let mut value_count_acc = 0; // Chunk value count prefix sum.
49
50    let mut chunks = values.chunks_exact(FL_CHUNK_SIZE);
51
52    let mut process_chunk = |chunk_start_idx: usize, input: &[T; FL_CHUNK_SIZE]| {
53        // SAFETY: `MaybeUninit<T>` and `T` have the same layout.
54        let rle_vals: &mut [T] =
55            unsafe { std::mem::transmute(&mut values_uninit[value_count_acc..][..FL_CHUNK_SIZE]) };
56
57        // SAFETY: `MaybeUninit<u16>` and `u16` have the same layout.
58        let rle_idxs: &mut [u16] =
59            unsafe { std::mem::transmute(&mut indices_uninit[chunk_start_idx..][..FL_CHUNK_SIZE]) };
60
61        // Capture chunk start indices. This is necessary as indices
62        // returned from `T::encode` are relative to the chunk.
63        values_idx_offsets.push(value_count_acc as u64);
64
65        let value_count = T::encode(
66            input,
67            array_mut_ref![rle_vals, 0, FL_CHUNK_SIZE],
68            array_mut_ref![rle_idxs, 0, FL_CHUNK_SIZE],
69        );
70
71        value_count_acc += value_count;
72    };
73
74    for (chunk_idx, chunk_slice) in chunks.by_ref().enumerate() {
75        process_chunk(
76            chunk_idx * FL_CHUNK_SIZE,
77            array_ref![chunk_slice, 0, FL_CHUNK_SIZE],
78        );
79    }
80
81    let remainder = chunks.remainder();
82    if !remainder.is_empty() {
83        // Repeat the last value for padding to prevent
84        // accounting for an additional value change.
85        let mut padded_chunk = [values[len - 1]; FL_CHUNK_SIZE];
86        padded_chunk[..remainder.len()].copy_from_slice(remainder);
87        process_chunk((len / FL_CHUNK_SIZE) * FL_CHUNK_SIZE, &padded_chunk);
88    }
89
90    unsafe {
91        values_buf.set_len(value_count_acc);
92        indices_buf.set_len(padded_len);
93    }
94
95    RLEArray::try_new(
96        values_buf.into_array(),
97        PrimitiveArray::new(indices_buf.freeze(), padded_validity(array)).into_array(),
98        values_idx_offsets.into_array(),
99        0,
100        array.len(),
101    )
102}
103
104/// Returns validity padded to the next 1024 chunk for a given array.
105fn padded_validity(array: &PrimitiveArray) -> Validity {
106    match array.validity() {
107        Validity::NonNullable => Validity::NonNullable,
108        Validity::AllValid => Validity::AllValid,
109        Validity::AllInvalid => Validity::AllInvalid,
110        Validity::Array(validity_array) => {
111            let len = array.len();
112            let padded_len = len.next_multiple_of(FL_CHUNK_SIZE);
113
114            if len == padded_len {
115                return Validity::Array(validity_array.clone());
116            }
117
118            let mut builder = BitBufferMut::with_capacity(padded_len);
119
120            let bool_array = validity_array.to_bool();
121            builder.append_buffer(bool_array.bit_buffer());
122            builder.append_n(false, padded_len - len);
123
124            Validity::from(builder.freeze())
125        }
126    }
127}
128
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::assert_arrays_eq;
    use vortex_buffer::Buffer;
    use vortex_dtype::half::f16;

    use super::*;

    #[test]
    fn test_encode_decode() {
        // u8: three runs of two elements each.
        let input_u8: Buffer<u8> = [1u8, 1, 2, 2, 3, 3].into_iter().collect();
        let roundtrip_u8 = RLEArray::encode(&input_u8.into_array().to_primitive())
            .unwrap()
            .to_primitive();
        let want_u8 = PrimitiveArray::from_iter(vec![1u8, 1, 2, 2, 3, 3]);
        assert_arrays_eq!(roundtrip_u8, want_u8);

        // u16: two runs of two elements each.
        let input_u16: Buffer<u16> = [100u16, 100, 200, 200].into_iter().collect();
        let roundtrip_u16 = RLEArray::encode(&input_u16.into_array().to_primitive())
            .unwrap()
            .to_primitive();
        let want_u16 = PrimitiveArray::from_iter(vec![100u16, 100, 200, 200]);
        assert_arrays_eq!(roundtrip_u16, want_u16);

        // u64: a run of two followed by a singleton.
        let input_u64: Buffer<u64> = [1000u64, 1000, 2000].into_iter().collect();
        let roundtrip_u64 = RLEArray::encode(&input_u64.into_array().to_primitive())
            .unwrap()
            .to_primitive();
        let want_u64 = PrimitiveArray::from_iter(vec![1000u64, 1000, 2000]);
        assert_arrays_eq!(roundtrip_u64, want_u64);
    }

    #[test]
    fn test_length() {
        let input: Buffer<u32> = [1u32, 1, 2, 2, 2, 3].into_iter().collect();
        let rle = RLEArray::encode(&input.into_array().to_primitive()).unwrap();
        assert_eq!(rle.len(), 6);
    }

    #[test]
    fn test_empty_length() {
        let input: Buffer<u32> = Buffer::empty();
        let rle = RLEArray::encode(&input.into_array().to_primitive()).unwrap();

        assert_eq!(rle.len(), 0);
        assert_eq!(rle.values().len(), 0);
    }

    #[test]
    fn test_single_value() {
        let input: Buffer<u16> = vec![42u16; 2000].into_iter().collect();
        let rle = RLEArray::encode(&input.into_array().to_primitive()).unwrap();

        // Two chunks, each contributing the single run value 42.
        assert_eq!(rle.values().len(), 2);

        let want = PrimitiveArray::from_iter(vec![42u16; 2000]);
        let roundtrip = rle.to_primitive();
        assert_arrays_eq!(roundtrip, want);
    }

    #[test]
    fn test_all_different() {
        let input: Buffer<u8> = (0u8..=255).collect();
        let rle = RLEArray::encode(&input.into_array().to_primitive()).unwrap();

        // Every element starts a new run.
        assert_eq!(rle.values().len(), 256);

        let want = PrimitiveArray::from_iter((0u8..=255).collect::<Vec<_>>());
        let roundtrip = rle.to_primitive();
        assert_arrays_eq!(roundtrip, want);
    }

    #[test]
    fn test_partial_last_chunk() {
        // 1500 elements: one full chunk of 1024 plus a partial chunk of 476.
        let input: Buffer<u32> = (0u32..1500).map(|i| i / 100).collect();
        let original = input.into_array();

        let rle = RLEArray::encode(&original.to_primitive()).unwrap();
        let roundtrip = rle.to_primitive();

        assert_eq!(rle.len(), 1500);
        assert_arrays_eq!(roundtrip, original);
        assert_eq!(rle.values_idx_offsets().len(), 2);
    }

    #[test]
    fn test_two_full_chunks() {
        // 2048 elements: exactly two full chunks, no padding.
        let input: Buffer<u32> = (0u32..2048).map(|i| i / 100).collect();
        let original = input.into_array();

        let rle = RLEArray::encode(&original.to_primitive()).unwrap();
        let roundtrip = rle.to_primitive();

        assert_eq!(rle.len(), 2048);
        assert_arrays_eq!(roundtrip, original);
        assert_eq!(rle.values_idx_offsets().len(), 2);
    }

    #[rstest]
    #[case::u8((0u8..100).collect::<Buffer<u8>>())]
    #[case::u16((0u16..2000).collect::<Buffer<u16>>())]
    #[case::u32((0u32..2000).collect::<Buffer<u32>>())]
    #[case::u64((0u64..2000).collect::<Buffer<u64>>())]
    #[case::i8((-100i8..100).collect::<Buffer<i8>>())]
    #[case::i16((-2000i16..2000).collect::<Buffer<i16>>())]
    #[case::i32((-2000i32..2000).collect::<Buffer<i32>>())]
    #[case::i64((-2000i64..2000).collect::<Buffer<i64>>())]
    #[case::f16((-2000..2000).map(|i| f16::from_f32(i as f32)).collect::<Buffer<f16>>())]
    #[case::f32((-2000..2000).map(|i| i as f32).collect::<Buffer<f32>>())]
    #[case::f64((-2000..2000).map(|i| i as f64).collect::<Buffer<f64>>())]
    fn test_roundtrip_primitive_types<T: NativePType>(#[case] values: Buffer<T>) {
        let primitive = values.clone().into_array().to_primitive();
        let roundtrip = RLEArray::encode(&primitive).unwrap().to_primitive();
        let want = PrimitiveArray::new(values, primitive.validity().clone());
        assert_arrays_eq!(roundtrip, want);
    }
}