vortex_array/arrays/primitive/compute/
take.rs1use std::simd;
2
3use num_traits::AsPrimitive;
4use simd::num::SimdUint;
5use vortex_buffer::{Alignment, Buffer, BufferMut};
6use vortex_dtype::{
7 NativePType, Nullability, PType, match_each_integer_ptype, match_each_native_ptype,
8 match_each_native_simd_ptype, match_each_unsigned_integer_ptype,
9};
10use vortex_error::VortexResult;
11
12use crate::arrays::PrimitiveVTable;
13use crate::arrays::primitive::PrimitiveArray;
14use crate::compute::{TakeKernel, TakeKernelAdapter};
15use crate::vtable::ValidityHelper;
16use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
17
18impl TakeKernel for PrimitiveVTable {
19 #[allow(clippy::cognitive_complexity)]
20 fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
21 let indices = indices.to_primitive()?;
22 let validity = array.validity().take(indices.as_ref())?;
23
24 if array.ptype() != PType::F16
25 && indices.dtype().is_unsigned_int()
26 && indices.all_valid()?
27 && array.all_valid()?
28 {
29 match_each_unsigned_integer_ptype!(indices.ptype(), |C| {
31 match_each_native_simd_ptype!(array.ptype(), |V| {
32 let decoded = take_primitive_simd::<C, V, 64>(
35 indices.as_slice(),
36 array.as_slice(),
37 array.dtype().nullability() | indices.dtype().nullability(),
38 );
39
40 return Ok(decoded.into_array()) as VortexResult<ArrayRef>;
41 })
42 });
43 }
44
45 match_each_native_ptype!(array.ptype(), |T| {
46 match_each_integer_ptype!(indices.ptype(), |I| {
47 let values = take_primitive(array.as_slice::<T>(), indices.as_slice::<I>());
48 Ok(PrimitiveArray::new(values, validity).into_array())
49 })
50 })
51 }
52}
53
54register_kernel!(TakeKernelAdapter(PrimitiveVTable).lift());
55
56fn take_primitive<T: NativePType, I: NativePType + AsPrimitive<usize>>(
57 array: &[T],
58 indices: &[I],
59) -> Buffer<T> {
60 indices.iter().map(|idx| array[idx.as_()]).collect()
61}
62
63fn take_primitive_simd<I, V, const LANE_COUNT: usize>(
79 indices: &[I],
80 values: &[V],
81 nullability: Nullability,
82) -> PrimitiveArray
83where
84 I: simd::SimdElement + AsPrimitive<usize>,
85 V: simd::SimdElement + NativePType,
86 simd::LaneCount<LANE_COUNT>: simd::SupportedLaneCount,
87 simd::Simd<I, LANE_COUNT>: SimdUint<Cast<usize> = simd::Simd<usize, LANE_COUNT>>,
88{
89 let indices_len = indices.len();
90
91 let mut buffer = BufferMut::<V>::with_capacity_aligned(
92 indices_len,
93 Alignment::of::<simd::Simd<V, LANE_COUNT>>(),
94 );
95
96 let buf_slice = buffer.spare_capacity_mut();
97
98 for chunk_idx in 0..(indices_len / LANE_COUNT) {
99 let offset = chunk_idx * LANE_COUNT;
100 let mask = simd::Mask::from_bitmask(u64::MAX);
101 let codes_chunk = simd::Simd::<I, LANE_COUNT>::from_slice(&indices[offset..]);
102
103 unsafe {
104 let selection = simd::Simd::gather_select_unchecked(
105 values,
106 mask,
107 codes_chunk.cast::<usize>(),
108 simd::Simd::<V, LANE_COUNT>::default(),
109 );
110
111 selection.store_select_ptr(buf_slice.as_mut_ptr().add(offset) as *mut V, mask.cast());
112 }
113 }
114
115 for idx in ((indices_len / LANE_COUNT) * LANE_COUNT)..indices_len {
116 unsafe {
117 buf_slice
118 .get_unchecked_mut(idx)
119 .write(values[indices[idx].as_()]);
120 }
121 }
122
123 unsafe {
124 buffer.set_len(indices_len);
125 }
126
127 PrimitiveArray::new(buffer.freeze(), nullability.into())
128}
129
#[cfg(test)]
mod test {
    use vortex_buffer::{Buffer, buffer};
    use vortex_scalar::Scalar;

    use crate::arrays::primitive::compute::take::take_primitive;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray, ToCanonical};

    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.scalar_at(0).unwrap(), Scalar::from(Some(1)));
        // Index 3 is valid but points at a null value.
        assert_eq!(actual.scalar_at(1).unwrap(), Scalar::null_typed::<i32>());
        // The index itself is null.
        assert_eq!(actual.scalar_at(2).unwrap(), Scalar::null_typed::<i32>());
    }

    #[test]
    fn test_take_simd_path() {
        // Unsigned, non-nullable indices over non-nullable, non-f16 values take the SIMD fast
        // path. 150 indices cover two full 64-lane chunks plus a 22-element scalar tail.
        let values = PrimitiveArray::new(
            (0..200i64).map(|v| v * 3).collect::<Buffer<i64>>(),
            Validity::NonNullable,
        );
        let index_values: Vec<u32> = (0..150u32).map(|i| (i * 7) % 200).collect();
        let indices = PrimitiveArray::new(
            index_values.iter().copied().collect::<Buffer<u32>>(),
            Validity::NonNullable,
        );

        let actual = take(values.as_ref(), indices.as_ref())
            .unwrap()
            .to_primitive()
            .unwrap();
        let expected: Vec<i64> = index_values.iter().map(|&i| i64::from(i) * 3).collect();
        assert_eq!(actual.as_slice::<i64>(), expected.as_slice());
    }
}