vortex_fastlanes/bitpacking/compute/take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::mem;
5use std::mem::MaybeUninit;
6
7use fastlanes::BitPacking;
8use vortex_array::Array;
9use vortex_array::ArrayRef;
10use vortex_array::IntoArray;
11use vortex_array::ToCanonical;
12use vortex_array::arrays::PrimitiveArray;
13use vortex_array::compute::TakeKernel;
14use vortex_array::compute::TakeKernelAdapter;
15use vortex_array::compute::take;
16use vortex_array::register_kernel;
17use vortex_array::validity::Validity;
18use vortex_array::vtable::ValidityHelper;
19use vortex_buffer::Buffer;
20use vortex_buffer::BufferMut;
21use vortex_dtype::IntegerPType;
22use vortex_dtype::NativePType;
23use vortex_dtype::PType;
24use vortex_dtype::match_each_integer_ptype;
25use vortex_dtype::match_each_unsigned_integer_ptype;
26use vortex_error::VortexExpect as _;
27use vortex_error::VortexResult;
28
29use super::chunked_indices;
30use crate::BitPackedArray;
31use crate::BitPackedVTable;
32use crate::bitpack_decompress;
33
// TODO(connor): This is duplicated in `encodings/fastlanes/src/bitpacking/kernels/mod.rs`.
/// Break-even point (number of requested indices per 1024-element chunk) above which
/// bulk-unpacking the whole chunk beats unpacking elements one at a time.
///
/// Assuming the buffer is already allocated (which will happen at most once), unpacking
/// all 1024 elements takes ~8.8x as long as unpacking a single element on an M2 Macbook Air.
/// See <https://github.com/vortex-data/vortex/pull/190#issue-2223752833>.
pub(super) const UNPACK_CHUNK_THRESHOLD: usize = 8;
39
40impl TakeKernel for BitPackedVTable {
41    fn take(&self, array: &BitPackedArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
42        // If the indices are large enough, it's faster to flatten and take the primitive array.
43        if indices.len() * UNPACK_CHUNK_THRESHOLD > array.len() {
44            return take(array.to_primitive().as_ref(), indices);
45        }
46
47        // NOTE: we use the unsigned PType because all values in the BitPackedArray must
48        //  be non-negative (pre-condition of creating the BitPackedArray).
49        let ptype: PType = PType::try_from(array.dtype())?;
50        let validity = array.validity();
51        let taken_validity = validity.take(indices)?;
52
53        let indices = indices.to_primitive();
54        let taken = match_each_unsigned_integer_ptype!(ptype.to_unsigned(), |T| {
55            match_each_integer_ptype!(indices.ptype(), |I| {
56                take_primitive::<T, I>(array, &indices, taken_validity)?
57            })
58        });
59        Ok(taken.reinterpret_cast(ptype).into_array())
60    }
61}
62
63register_kernel!(TakeKernelAdapter(BitPackedVTable).lift());
64
/// Take values out of a bit-packed array by unpacking only the 1024-element FastLanes chunks
/// that the requested `indices` actually touch.
///
/// `T` is the *unsigned* physical type the data is packed as; `I` is the integer type of
/// `indices`. `taken_validity` is the validity already computed for the taken positions
/// (see the caller in `TakeKernel::take`). Patch values, if any, are applied at the end.
fn take_primitive<T: NativePType + BitPacking, I: IntegerPType>(
    array: &BitPackedArray,
    indices: &PrimitiveArray,
    taken_validity: Validity,
) -> VortexResult<PrimitiveArray> {
    if indices.is_empty() {
        return Ok(PrimitiveArray::new(Buffer::<T>::empty(), taken_validity));
    }

    // The array may be a slice starting mid-chunk; `offset` shifts incoming indices so they
    // line up with the underlying packed chunks.
    let offset = array.offset() as usize;
    let bit_width = array.bit_width() as usize;

    let packed = array.packed_slice::<T>();

    // Group indices by 1024-element chunk, *without* allocating on the heap
    let indices_iter = indices.as_slice::<I>().iter().map(|i| {
        i.to_usize()
            .vortex_expect("index must be expressible as usize")
    });

    let mut output = BufferMut::<T>::with_capacity(indices.len());
    // Scratch space for one fully unpacked chunk; only written when we decide to bulk-unpack.
    let mut unpacked = [const { MaybeUninit::uninit() }; 1024];
    // Packed length of one chunk in `T` lanes: 1024 values * bit_width bits / (8 * size_of::<T>()).
    let chunk_len = 128 * bit_width / size_of::<T>();

    chunked_indices(indices_iter, offset, |chunk_idx, indices_within_chunk| {
        // The packed words belonging to this chunk only.
        let packed = &packed[chunk_idx * chunk_len..][..chunk_len];

        // Lazily bulk-unpack the chunk the first time we see a full group of offsets.
        let mut have_unpacked = false;
        let mut offset_chunk_iter = indices_within_chunk.chunks_exact(UNPACK_CHUNK_THRESHOLD);

        // this loop only runs if we have at least UNPACK_CHUNK_THRESHOLD offsets
        for offset_chunk in &mut offset_chunk_iter {
            assert_eq!(offset_chunk.len(), UNPACK_CHUNK_THRESHOLD); // let compiler know slice length
            if !have_unpacked {
                // SAFETY: `MaybeUninit<T>` has the same layout as `T`, and `unchecked_unpack`
                // writes all 1024 lanes, so the scratch buffer is fully initialized afterwards.
                // `packed` was sliced to exactly `chunk_len` words above.
                unsafe {
                    let dst: &mut [MaybeUninit<T>] = &mut unpacked;
                    let dst: &mut [T] = mem::transmute(dst);
                    BitPacking::unchecked_unpack(bit_width, packed, dst);
                }
                have_unpacked = true;
            }

            for &index in offset_chunk {
                // SAFETY: the bulk unpack above initialized all 1024 elements, and
                // `chunked_indices` yields offsets within the chunk, so `index < 1024`.
                output.push(unsafe { unpacked[index].assume_init() });
            }
        }

        // if we have a remainder (i.e., < UNPACK_CHUNK_THRESHOLD leftover offsets), we need to handle it
        if !offset_chunk_iter.remainder().is_empty() {
            if have_unpacked {
                // we already bulk unpacked this chunk, so we can just push the remaining elements
                for &index in offset_chunk_iter.remainder() {
                    // SAFETY: same as above — the whole chunk was initialized by the bulk unpack.
                    output.push(unsafe { unpacked[index].assume_init() });
                }
            } else {
                // we had fewer than UNPACK_CHUNK_THRESHOLD offsets in the first place,
                // so we need to unpack each one individually
                for &index in offset_chunk_iter.remainder() {
                    // SAFETY: `index` is an in-chunk offset and `packed` covers the full chunk
                    // at `bit_width`, satisfying `unpack_single_primitive`'s contract.
                    output.push(unsafe {
                        bitpack_decompress::unpack_single_primitive::<T>(packed, bit_width, index)
                    });
                }
            }
        }
    });

    let mut unpatched_taken = PrimitiveArray::new(output, taken_validity);
    // Flip back to signed type before patching.
    if array.ptype().is_signed_int() {
        unpatched_taken = unpatched_taken.reinterpret_cast(array.ptype());
    }
    // Apply patches (values stored outside the packed data) at the taken positions, if any.
    if let Some(patches) = array.patches()
        && let Some(patches) = patches.take(indices.as_ref())?
    {
        let cast_patches = patches.cast_values(unpatched_taken.dtype())?;
        return Ok(unpatched_taken.patch(&cast_patches));
    }

    Ok(unpatched_taken)
}
145
#[cfg(test)]
#[allow(clippy::cast_possible_truncation)]
mod test {
    use rand::Rng;
    use rand::distr::Uniform;
    use rand::rng;
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::take;
    use vortex_array::validity::Validity;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;

    use crate::BitPackedArray;
    use crate::bitpacking::compute::take::take_primitive;

    // Scattered indices spanning several 1024-element chunks of a patch-free array.
    #[test]
    fn take_indices() {
        let indices = buffer![0, 125, 2047, 2049, 2151, 2790].into_array();

        // Create a u8 array modulo 63.
        let unpacked = PrimitiveArray::from_iter((0..4096).map(|i| (i % 63) as u8));
        let bitpacked = BitPackedArray::encode(unpacked.as_ref(), 6).unwrap();

        let primitive_result = take(bitpacked.as_ref(), &indices).unwrap().to_primitive();
        assert_arrays_eq!(
            primitive_result,
            PrimitiveArray::from_iter([0u8, 62, 31, 33, 9, 18])
        );
    }

    // bit_width 2 can only represent 0..4, so most of 0..1024 spills into patches;
    // take must still return the original values.
    #[test]
    fn take_with_patches() {
        let unpacked = Buffer::from_iter(0u32..1024).into_array();
        let bitpacked = BitPackedArray::encode(&unpacked, 2).unwrap();

        let indices = buffer![0, 2, 4, 6].into_array();

        let primitive_result = take(bitpacked.as_ref(), indices.as_ref())
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(primitive_result, PrimitiveArray::from_iter([0u32, 2, 4, 6]));
    }

    // Slicing shifts the array's offset relative to the packed chunks; indices into the
    // slice must be translated correctly.
    #[test]
    fn take_sliced_indices() {
        let indices = buffer![1919, 1921].into_array();

        // Create a u8 array modulo 63.
        let unpacked = PrimitiveArray::from_iter((0..4096).map(|i| (i % 63) as u8));
        let bitpacked = BitPackedArray::encode(unpacked.as_ref(), 6).unwrap();
        let sliced = bitpacked.slice(128..2050);

        let primitive_result = take(&sliced, &indices).unwrap().to_primitive();
        assert_arrays_eq!(primitive_result, PrimitiveArray::from_iter([31u8, 33]));
    }

    // Randomized take over an array whose top values exceed 16 bits and so become patches;
    // cross-checks every taken element against scalar_at and the source buffer.
    #[test]
    #[cfg_attr(miri, ignore)] // This test is too slow on miri
    fn take_random_indices() {
        let num_patches: usize = 128;
        let values = (0..u16::MAX as u32 + num_patches as u32).collect::<Buffer<_>>();
        let uncompressed = PrimitiveArray::new(values.clone(), Validity::NonNullable);
        let packed = BitPackedArray::encode(uncompressed.as_ref(), 16).unwrap();
        assert!(packed.patches().is_some());

        let rng = rng();
        let range = Uniform::new(0, values.len()).unwrap();
        let random_indices =
            PrimitiveArray::from_iter(rng.sample_iter(range).take(10_000).map(|i| i as u32));
        let taken = take(packed.as_ref(), random_indices.as_ref()).unwrap();

        // sanity check
        random_indices
            .as_slice::<u32>()
            .iter()
            .enumerate()
            .for_each(|(ti, i)| {
                assert_eq!(
                    u32::try_from(&packed.scalar_at(*i as usize)).unwrap(),
                    values[*i as usize]
                );
                assert_eq!(
                    u32::try_from(&taken.scalar_at(ti)).unwrap(),
                    values[*i as usize]
                );
            });
    }

    // Calls take_primitive directly with the unsigned physical type (u32) for signed (i32)
    // data, mirroring how TakeKernel::take dispatches it.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn take_signed_with_patches() {
        let start =
            BitPackedArray::encode(&buffer![1i32, 2i32, 3i32, 4i32].into_array(), 1).unwrap();

        let taken_primitive = take_primitive::<u32, u64>(
            &start,
            &PrimitiveArray::from_iter([0u64, 1, 2, 3]),
            Validity::NonNullable,
        )
        .unwrap();
        assert_arrays_eq!(taken_primitive, PrimitiveArray::from_iter([1i32, 2, 3, 4]));
    }

    // Null entries in the indices array must propagate into the result's validity.
    #[test]
    fn take_nullable_with_nullables() {
        let start =
            BitPackedArray::encode(&buffer![1i32, 2i32, 3i32, 4i32].into_array(), 1).unwrap();

        let taken_primitive = take(
            start.as_ref(),
            PrimitiveArray::from_option_iter([Some(0u64), Some(1), None, Some(3)]).as_ref(),
        )
        .unwrap()
        .to_primitive();
        assert_arrays_eq!(
            taken_primitive,
            PrimitiveArray::from_option_iter([Some(1i32), Some(2), None, Some(4)])
        );
        assert_eq!(taken_primitive.invalid_count(), 1);
    }

    // Shared take-conformance suite over a spread of encodings: multiple widths/types,
    // nullable values, a single-element array, and an exact full-chunk (1024) array.
    #[rstest]
    #[case(BitPackedArray::encode(PrimitiveArray::from_iter((0..100).map(|i| (i % 63) as u8)).as_ref(), 6).unwrap())]
    #[case(BitPackedArray::encode(PrimitiveArray::from_iter((0..256).map(|i| i as u32)).as_ref(), 8).unwrap())]
    #[case(BitPackedArray::encode(buffer![1i32, 2, 3, 4, 5, 6, 7, 8].into_array().as_ref(), 3).unwrap())]
    #[case(BitPackedArray::encode(
        PrimitiveArray::from_option_iter([Some(10u16), None, Some(20), Some(30), None]).as_ref(),
        5
    ).unwrap())]
    #[case(BitPackedArray::encode(buffer![42u32].into_array().as_ref(), 6).unwrap())]
    #[case(BitPackedArray::encode(PrimitiveArray::from_iter((0..1024).map(|i| i as u32)).as_ref(), 8).unwrap())]
    fn test_take_bitpacked_conformance(#[case] bitpacked: BitPackedArray) {
        use vortex_array::compute::conformance::take::test_take_conformance;
        test_take_conformance(bitpacked.as_ref());
    }
}