// vortex_fastlanes/bitpacking/compute/take.rs
use std::mem;
5use std::mem::MaybeUninit;
6
7use fastlanes::BitPacking;
8use vortex_array::Array;
9use vortex_array::ArrayRef;
10use vortex_array::ExecutionCtx;
11use vortex_array::IntoArray;
12use vortex_array::ToCanonical;
13use vortex_array::arrays::PrimitiveArray;
14use vortex_array::arrays::TakeExecute;
15use vortex_array::dtype::IntegerPType;
16use vortex_array::dtype::NativePType;
17use vortex_array::dtype::PType;
18use vortex_array::match_each_integer_ptype;
19use vortex_array::match_each_unsigned_integer_ptype;
20use vortex_array::validity::Validity;
21use vortex_array::vtable::ValidityHelper;
22use vortex_buffer::Buffer;
23use vortex_buffer::BufferMut;
24use vortex_error::VortexExpect as _;
25use vortex_error::VortexResult;
26
27use super::chunked_indices;
28use crate::BitPackedArray;
29use crate::BitPackedVTable;
30use crate::bitpack_decompress;
31
32pub(super) const UNPACK_CHUNK_THRESHOLD: usize = 8;
37
38impl TakeExecute for BitPackedVTable {
39 fn take(
40 array: &BitPackedArray,
41 indices: &dyn Array,
42 _ctx: &mut ExecutionCtx,
43 ) -> VortexResult<Option<ArrayRef>> {
44 if indices.len() * UNPACK_CHUNK_THRESHOLD > array.len() {
46 return array.to_primitive().take(indices.to_array()).map(Some);
47 }
48
49 let ptype: PType = PType::try_from(array.dtype())?;
52 let validity = array.validity();
53 let taken_validity = validity.take(indices)?;
54
55 let indices = indices.to_primitive();
56 let taken = match_each_unsigned_integer_ptype!(ptype.to_unsigned(), |T| {
57 match_each_integer_ptype!(indices.ptype(), |I| {
58 take_primitive::<T, I>(array, &indices, taken_validity)?
59 })
60 });
61 Ok(Some(taken.reinterpret_cast(ptype).into_array()))
62 }
63}
64
65fn take_primitive<T: NativePType + BitPacking, I: IntegerPType>(
66 array: &BitPackedArray,
67 indices: &PrimitiveArray,
68 taken_validity: Validity,
69) -> VortexResult<PrimitiveArray> {
70 if indices.is_empty() {
71 return Ok(PrimitiveArray::new(Buffer::<T>::empty(), taken_validity));
72 }
73
74 let offset = array.offset() as usize;
75 let bit_width = array.bit_width() as usize;
76
77 let packed = array.packed_slice::<T>();
78
79 let indices_iter = indices.as_slice::<I>().iter().map(|i| {
81 i.to_usize()
82 .vortex_expect("index must be expressible as usize")
83 });
84
85 let mut output = BufferMut::<T>::with_capacity(indices.len());
86 let mut unpacked = [const { MaybeUninit::uninit() }; 1024];
87 let chunk_len = 128 * bit_width / size_of::<T>();
88
89 chunked_indices(indices_iter, offset, |chunk_idx, indices_within_chunk| {
90 let packed = &packed[chunk_idx * chunk_len..][..chunk_len];
91
92 let mut have_unpacked = false;
93 let mut offset_chunk_iter = indices_within_chunk.chunks_exact(UNPACK_CHUNK_THRESHOLD);
94
95 for offset_chunk in &mut offset_chunk_iter {
97 assert_eq!(offset_chunk.len(), UNPACK_CHUNK_THRESHOLD); if !have_unpacked {
99 unsafe {
100 let dst: &mut [MaybeUninit<T>] = &mut unpacked;
101 let dst: &mut [T] = mem::transmute(dst);
102 BitPacking::unchecked_unpack(bit_width, packed, dst);
103 }
104 have_unpacked = true;
105 }
106
107 for &index in offset_chunk {
108 output.push(unsafe { unpacked[index].assume_init() });
109 }
110 }
111
112 if !offset_chunk_iter.remainder().is_empty() {
114 if have_unpacked {
115 for &index in offset_chunk_iter.remainder() {
117 output.push(unsafe { unpacked[index].assume_init() });
118 }
119 } else {
120 for &index in offset_chunk_iter.remainder() {
123 output.push(unsafe {
124 bitpack_decompress::unpack_single_primitive::<T>(packed, bit_width, index)
125 });
126 }
127 }
128 }
129 });
130
131 let mut unpatched_taken = PrimitiveArray::new(output, taken_validity);
132 if array.ptype().is_signed_int() {
134 unpatched_taken = unpatched_taken.reinterpret_cast(array.ptype());
135 }
136 if let Some(patches) = array.patches()
137 && let Some(patches) = patches.take(indices.as_ref())?
138 {
139 let cast_patches = patches.cast_values(unpatched_taken.dtype())?;
140 return unpatched_taken.patch(&cast_patches);
141 }
142
143 Ok(unpatched_taken)
144}
145
#[cfg(test)]
// The fixtures below are built with narrowing `as` casts (e.g. `(i % 63) as u8`).
#[allow(clippy::cast_possible_truncation)]
mod test {
    use rand::Rng;
    use rand::distr::Uniform;
    use rand::rng;
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::validity::Validity;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;

    use crate::BitPackedArray;
    use crate::bitpacking::compute::take::take_primitive;

    /// Takes a handful of indices spread over a 4096-element array, including positions on
    /// both sides of index 2048.
    #[test]
    fn take_indices() {
        let indices = buffer![0, 125, 2047, 2049, 2151, 2790].into_array();

        let unpacked = PrimitiveArray::from_iter((0..4096).map(|i| (i % 63) as u8));
        let bitpacked = BitPackedArray::encode(unpacked.as_ref(), 6).unwrap();

        let primitive_result = bitpacked.take(indices.to_array()).unwrap();
        assert_arrays_eq!(
            primitive_result,
            PrimitiveArray::from_iter([0u8, 62, 31, 33, 9, 18])
        );
    }

    /// Encodes `0..1024` at bit width 2, so most values do not fit in the packed width, and
    /// checks that taken values on both sides of that limit come back unchanged.
    #[test]
    fn take_with_patches() {
        let unpacked = Buffer::from_iter(0u32..1024).into_array();
        let bitpacked = BitPackedArray::encode(&unpacked, 2).unwrap();

        let indices = buffer![0, 2, 4, 6].into_array();

        let primitive_result = bitpacked.take(indices.to_array()).unwrap();
        assert_arrays_eq!(primitive_result, PrimitiveArray::from_iter([0u32, 2, 4, 6]));
    }

    /// Takes from a slice starting at 128: indices 1919 and 1921 map to positions 2047 and
    /// 2049 of the unsliced array.
    #[test]
    fn take_sliced_indices() {
        let indices = buffer![1919, 1921].into_array();

        let unpacked = PrimitiveArray::from_iter((0..4096).map(|i| (i % 63) as u8));
        let bitpacked = BitPackedArray::encode(unpacked.as_ref(), 6).unwrap();
        let sliced = bitpacked.slice(128..2050).unwrap();

        let primitive_result = sliced.take(indices.to_array()).unwrap();
        assert_arrays_eq!(primitive_result, PrimitiveArray::from_iter([31u8, 33]));
    }

    /// Takes 10,000 random indices from a patched array and compares both `scalar_at` on the
    /// packed array and the taken result against the source values.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn take_random_indices() {
        let num_patches: usize = 128;
        let values = (0..u16::MAX as u32 + num_patches as u32).collect::<Buffer<_>>();
        let uncompressed = PrimitiveArray::new(values.clone(), Validity::NonNullable);
        let packed = BitPackedArray::encode(uncompressed.as_ref(), 16).unwrap();
        assert!(packed.patches().is_some());

        let rng = rng();
        let range = Uniform::new(0, values.len()).unwrap();
        let random_indices =
            PrimitiveArray::from_iter(rng.sample_iter(range).take(10_000).map(|i| i as u32));
        let taken = packed.take(random_indices.to_array()).unwrap();

        random_indices
            .as_slice::<u32>()
            .iter()
            .enumerate()
            .for_each(|(ti, i)| {
                assert_eq!(
                    u32::try_from(&packed.scalar_at(*i as usize).unwrap()).unwrap(),
                    values[*i as usize]
                );
                assert_eq!(
                    u32::try_from(&taken.scalar_at(ti).unwrap()).unwrap(),
                    values[*i as usize]
                );
            });
    }

    /// Calls `take_primitive` directly with an unsigned `T` on an `i32` array and expects the
    /// result to come back as `i32`.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn take_signed_with_patches() {
        let start =
            BitPackedArray::encode(&buffer![1i32, 2i32, 3i32, 4i32].into_array(), 1).unwrap();

        let taken_primitive = take_primitive::<u32, u64>(
            &start,
            &PrimitiveArray::from_iter([0u64, 1, 2, 3]),
            Validity::NonNullable,
        )
        .unwrap();
        assert_arrays_eq!(taken_primitive, PrimitiveArray::from_iter([1i32, 2, 3, 4]));
    }

    /// A null index must produce a null in the output, and only that one.
    #[test]
    fn take_nullable_with_nullables() {
        let start =
            BitPackedArray::encode(&buffer![1i32, 2i32, 3i32, 4i32].into_array(), 1).unwrap();

        let taken_primitive = start
            .take(PrimitiveArray::from_option_iter([Some(0u64), Some(1), None, Some(3)]).to_array())
            .unwrap();
        assert_arrays_eq!(
            taken_primitive,
            PrimitiveArray::from_option_iter([Some(1i32), Some(2), None, Some(4)])
        );
        assert_eq!(taken_primitive.to_primitive().invalid_count().unwrap(), 1);
    }

    /// Runs the shared take conformance suite over a variety of bit-packed inputs.
    #[rstest]
    #[case(BitPackedArray::encode(PrimitiveArray::from_iter((0..100).map(|i| (i % 63) as u8)).as_ref(), 6).unwrap())]
    #[case(BitPackedArray::encode(PrimitiveArray::from_iter((0..256).map(|i| i as u32)).as_ref(), 8).unwrap())]
    #[case(BitPackedArray::encode(buffer![1i32, 2, 3, 4, 5, 6, 7, 8].into_array().as_ref(), 3).unwrap())]
    #[case(BitPackedArray::encode(
        PrimitiveArray::from_option_iter([Some(10u16), None, Some(20), Some(30), None]).as_ref(),
        5
    ).unwrap())]
    #[case(BitPackedArray::encode(buffer![42u32].into_array().as_ref(), 6).unwrap())]
    #[case(BitPackedArray::encode(PrimitiveArray::from_iter((0..1024).map(|i| i as u32)).as_ref(), 8).unwrap())]
    fn test_take_bitpacked_conformance(#[case] bitpacked: BitPackedArray) {
        use vortex_array::compute::conformance::take::test_take_conformance;
        test_take_conformance(bitpacked.as_ref());
    }
}