1use std::simd;
2
3use num_traits::AsPrimitive;
4use simd::num::SimdUint;
5use vortex_buffer::{Alignment, Buffer, BufferMut};
6use vortex_dtype::{
7 NativePType, Nullability, PType, match_each_integer_ptype, match_each_native_ptype,
8 match_each_native_simd_ptype, match_each_unsigned_integer_ptype,
9};
10use vortex_error::{VortexResult, vortex_err};
11use vortex_mask::Mask;
12
13use crate::arrays::PrimitiveEncoding;
14use crate::arrays::primitive::PrimitiveArray;
15use crate::builders::{ArrayBuilder, PrimitiveBuilder};
16use crate::compute::TakeFn;
17use crate::variants::PrimitiveArrayTrait;
18use crate::{Array, ArrayRef, ToCanonical};
19
20impl TakeFn<&PrimitiveArray> for PrimitiveEncoding {
21 #[allow(clippy::cognitive_complexity)]
22 fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
23 let indices = indices.to_primitive()?;
24 let validity = array.validity().take(&indices)?;
25
26 if array.ptype() != PType::F16
27 && indices.dtype().is_unsigned_int()
28 && indices.all_valid()?
29 && array.all_valid()?
30 {
31 match_each_unsigned_integer_ptype!(indices.ptype(), |$C| {
33 match_each_native_simd_ptype!(array.ptype(), |$V| {
34 let decoded = take_primitive_simd::<$C, $V, 64>(
37 indices.as_slice(),
38 array.as_slice(),
39 array.dtype().nullability() | indices.dtype().nullability(),
40 );
41
42 return Ok(decoded.into_array()) as VortexResult<ArrayRef>;
43 })
44 });
45 }
46
47 match_each_native_ptype!(array.ptype(), |$T| {
48 match_each_integer_ptype!(indices.ptype(), |$I| {
49 let values = take_primitive(array.as_slice::<$T>(), indices.as_slice::<$I>());
50 Ok(PrimitiveArray::new(values, validity).into_array())
51 })
52 })
53 }
54
55 fn take_into(
56 &self,
57 array: &PrimitiveArray,
58 indices: &dyn Array,
59 builder: &mut dyn ArrayBuilder,
60 ) -> VortexResult<()> {
61 let indices = indices.to_primitive()?;
62 let validity = array.validity().take(&indices)?;
64 let mask = validity.to_logical(indices.len())?;
65
66 match_each_native_ptype!(array.ptype(), |$T| {
67 match_each_integer_ptype!(indices.ptype(), |$I| {
68 take_into_impl::<$T, $I>(array, &indices, mask, builder)
69 })
70 })
71 }
72}
73
74fn take_into_impl<T: NativePType, I: NativePType + AsPrimitive<usize>>(
75 array: &PrimitiveArray,
76 indices: &PrimitiveArray,
77 mask: Mask,
78 builder: &mut dyn ArrayBuilder,
79) -> VortexResult<()> {
80 assert_eq!(indices.len(), mask.len());
81
82 let array = array.as_slice::<T>();
83 let indices = indices.as_slice::<I>();
84 let builder = builder
85 .as_any_mut()
86 .downcast_mut::<PrimitiveBuilder<T>>()
87 .ok_or_else(|| {
88 vortex_err!(
89 "Failed to downcast builder to PrimitiveBuilder<{}>",
90 T::PTYPE
91 )
92 })?;
93 builder.extend_with_iterator(indices.iter().map(|idx| array[idx.as_()]), mask);
94 Ok(())
95}
96
97fn take_primitive<T: NativePType, I: NativePType + AsPrimitive<usize>>(
98 array: &[T],
99 indices: &[I],
100) -> Buffer<T> {
101 indices.iter().map(|idx| array[idx.as_()]).collect()
102}
103
104fn take_primitive_simd<I, V, const LANE_COUNT: usize>(
120 indices: &[I],
121 values: &[V],
122 nullability: Nullability,
123) -> PrimitiveArray
124where
125 I: simd::SimdElement + AsPrimitive<usize>,
126 V: simd::SimdElement + NativePType,
127 simd::LaneCount<LANE_COUNT>: simd::SupportedLaneCount,
128 simd::Simd<I, LANE_COUNT>: SimdUint<Cast<usize> = simd::Simd<usize, LANE_COUNT>>,
129{
130 let indices_len = indices.len();
131
132 let mut buffer = BufferMut::<V>::with_capacity_aligned(
133 indices_len,
134 Alignment::of::<simd::Simd<V, LANE_COUNT>>(),
135 );
136
137 let buf_slice = buffer.spare_capacity_mut();
138
139 for chunk_idx in 0..(indices_len / LANE_COUNT) {
140 let offset = chunk_idx * LANE_COUNT;
141 let mask = simd::Mask::from_bitmask(u64::MAX);
142 let codes_chunk = simd::Simd::<I, LANE_COUNT>::from_slice(&indices[offset..]);
143
144 unsafe {
145 let selection = simd::Simd::gather_select_unchecked(
146 values,
147 mask,
148 codes_chunk.cast::<usize>(),
149 simd::Simd::<V, LANE_COUNT>::default(),
150 );
151
152 selection.store_select_ptr(buf_slice.as_mut_ptr().add(offset) as *mut V, mask.cast());
153 }
154 }
155
156 for idx in ((indices_len / LANE_COUNT) * LANE_COUNT)..indices_len {
157 unsafe {
158 buf_slice
159 .get_unchecked_mut(idx)
160 .write(values[indices[idx].as_()]);
161 }
162 }
163
164 unsafe {
165 buffer.set_len(indices_len);
166 }
167
168 PrimitiveArray::new(buffer.freeze(), nullability.into())
169}
170
#[cfg(test)]
mod test {
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use crate::array::Array;
    use crate::arrays::primitive::compute::take::take_primitive;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::builders::{ArrayBuilder as _, PrimitiveBuilder};
    use crate::compute::{scalar_at, take, take_into};
    use crate::validity::Validity;

    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = take(&values, &indices).unwrap();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::from(Some(1)));
        // Index 3 is valid, but the value at position 3 is null.
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::null_typed::<i32>());
        // The index itself is null.
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<i32>());
    }

    /// Unsigned, non-null indices over non-null, non-F16 values take the SIMD fast path.
    /// 150 indices cover two full 64-lane chunks plus a 22-element scalar tail.
    #[test]
    fn test_take_simd_path_matches_scalar() {
        let raw_values: Vec<i64> = (0..200).map(|v| v * 3 - 100).collect();
        let raw_indices: Vec<u32> = (0..150).map(|i| (i * 7) % 200).collect();
        let expected = take_primitive(&raw_values, &raw_indices);

        let values = PrimitiveArray::new(
            raw_values.iter().copied().collect::<Buffer<i64>>(),
            Validity::NonNullable,
        );
        let indices = PrimitiveArray::new(
            raw_indices.iter().copied().collect::<Buffer<u32>>(),
            Validity::NonNullable,
        );
        let actual = take(&values, &indices).unwrap();

        assert_eq!(actual.len(), expected.len());
        for (i, &v) in expected.as_slice().iter().enumerate() {
            assert_eq!(scalar_at(&actual, i).unwrap(), Scalar::from(v));
        }
    }

    #[test]
    fn test_take_into() {
        let values = PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5], Validity::NonNullable);
        let all_valid_indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, true]).into_array()),
        );
        let mut builder = PrimitiveBuilder::<i32>::new(Nullability::Nullable);
        take_into(&values, &all_valid_indices, &mut builder).unwrap();
        let actual = builder.finish();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::from(Some(1)));
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::from(Some(4)));
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::from(Some(5)));

        let mixed_valid_indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let mut builder = PrimitiveBuilder::<i32>::new(Nullability::Nullable);
        take_into(&values, &mixed_valid_indices, &mut builder).unwrap();
        let actual = builder.finish();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::from(Some(1)));
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::from(Some(4)));
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<i32>());

        let all_invalid_indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([false, false, false]).into_array()),
        );
        let mut builder = PrimitiveBuilder::<i32>::new(Nullability::Nullable);
        take_into(&values, &all_invalid_indices, &mut builder).unwrap();
        let actual = builder.finish();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::null_typed::<i32>());
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::null_typed::<i32>());
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<i32>());

        let non_null_indices = PrimitiveArray::new(buffer![0, 3, 4], Validity::NonNullable);
        let mut builder = PrimitiveBuilder::<i32>::new(Nullability::NonNullable);
        take_into(&values, &non_null_indices, &mut builder).unwrap();
        let actual = builder.finish();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::from(1));
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::from(4));
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::from(5));
    }
}