vortex_array/arrays/primitive/compute/take/
mod.rs1#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
5mod avx2;
6
7use std::sync::LazyLock;
8
9use vortex_buffer::Buffer;
10use vortex_buffer::BufferMut;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_mask::Mask;
14
15use crate::ArrayRef;
16use crate::IntoArray;
17use crate::array::ArrayView;
18use crate::arrays::ConstantArray;
19use crate::arrays::Primitive;
20use crate::arrays::PrimitiveArray;
21use crate::arrays::dict::TakeExecute;
22use crate::builtins::ArrayBuiltins;
23use crate::dtype::DType;
24use crate::dtype::IntegerPType;
25use crate::executor::ExecutionCtx;
26use crate::match_each_integer_ptype;
27use crate::match_each_native_ptype;
28use crate::scalar::Scalar;
29use crate::validity::Validity;
30
31static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| {
34 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
35 {
36 if is_x86_feature_detected!("avx2") {
37 &avx2::TakeKernelAVX2
38 } else {
39 &TakeKernelScalar
40 }
41 }
42
43 #[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
44 {
45 &TakeKernelScalar
46 }
47});
48
49trait TakeImpl: Send + Sync {
50 fn take(
51 &self,
52 array: ArrayView<'_, Primitive>,
53 indices: ArrayView<'_, Primitive>,
54 validity: Validity,
55 ) -> VortexResult<ArrayRef>;
56}
57
58struct TakeKernelScalar;
59
60impl TakeImpl for TakeKernelScalar {
61 fn take(
62 &self,
63 array: ArrayView<'_, Primitive>,
64 indices: ArrayView<'_, Primitive>,
65 validity: Validity,
66 ) -> VortexResult<ArrayRef> {
67 match_each_native_ptype!(array.ptype(), |T| {
68 match_each_integer_ptype!(indices.ptype(), |I| {
69 let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
70 Ok(PrimitiveArray::new(values, validity).into_array())
71 })
72 })
73 }
74}
75
76impl TakeExecute for Primitive {
77 fn take(
78 array: ArrayView<'_, Primitive>,
79 indices: &ArrayRef,
80 ctx: &mut ExecutionCtx,
81 ) -> VortexResult<Option<ArrayRef>> {
82 let DType::Primitive(ptype, null) = indices.dtype() else {
83 vortex_bail!("Invalid indices dtype: {}", indices.dtype())
84 };
85
86 let indices_validity = indices.validity()?;
87 let indices_nulls_zeroed = match indices_validity.execute_mask(indices.len(), ctx)? {
91 Mask::AllTrue(_) => indices.clone(),
92 Mask::AllFalse(_) => {
93 return Ok(Some(
94 ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
95 .into_array(),
96 ));
97 }
98 Mask::Values(_) => indices
99 .clone()
100 .fill_null(Scalar::from(0).cast(indices.dtype())?)?,
101 };
102
103 let unsigned_indices = if ptype.is_unsigned_int() {
104 indices_nulls_zeroed.execute::<PrimitiveArray>(ctx)?
105 } else {
106 indices_nulls_zeroed
108 .cast(DType::Primitive(ptype.to_unsigned(), *null))?
109 .execute::<PrimitiveArray>(ctx)?
110 };
111
112 let validity = array
113 .validity()?
114 .take(&unsigned_indices.clone().into_array())?
115 .and(indices_validity)?;
116 {
118 let unsigned_indices = unsigned_indices.as_view();
119 PRIMITIVE_TAKE_KERNEL
120 .take(array, unsigned_indices, validity)
121 .map(Some)
122 }
123 }
124}
125
126#[inline(always)]
128fn take_primitive_scalar<T: Copy, I: IntegerPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {
129 let mut result = BufferMut::with_capacity(indices.len());
133 let ptr = result.spare_capacity_mut().as_mut_ptr().cast::<T>();
134
135 for (i, idx) in indices.iter().enumerate() {
138 unsafe { ptr.add(i).write(buffer[idx.as_()]) };
140 }
141
142 unsafe { result.set_len(indices.len()) };
144 result.freeze()
145}
146
// NOTE(review): this module is only compiled on x86/x86_64 although nothing in it looks
// architecture-specific — confirm whether the gate is still needed.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::primitive::compute::take::take_primitive_scalar;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// The scalar gather copies one source element per index, repeated indices included.
    #[test]
    fn test_take() {
        let source = vec![1i32, 2, 3, 4, 5];
        let gathered = take_primitive_scalar(&source, &[0, 0, 4, 2]);
        assert_eq!(gathered.as_slice(), &[1i32, 1, 5, 3]);
    }

    /// A result element is null when either the index is null or the value it points at is null.
    #[test]
    fn test_take_with_null_indices() {
        let value_validity =
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array());
        let values = PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5], value_validity);

        let index_validity = Validity::Array(BoolArray::from_iter([true, true, false]).into_array());
        let indices = PrimitiveArray::new(buffer![0, 3, 4], index_validity);

        let actual = values.take(indices.into_array()).unwrap();

        // Position 0: valid index -> valid value. Position 1: valid index -> null value.
        // Position 2: null index.
        let expectations = [
            (0, Scalar::from(Some(1))),
            (1, Scalar::null_native::<i32>()),
            (2, Scalar::null_native::<i32>()),
        ];
        for (position, expected) in expectations {
            let mut ctx = array_session().create_execution_ctx();
            let observed = actual
                .execute_scalar(position, &mut ctx)
                .vortex_expect("no fail");
            assert_eq!(observed, expected);
        }
    }

    /// Runs the shared take conformance suite over a range of primitive arrays.
    #[rstest]
    #[case(PrimitiveArray::new(buffer![42i32], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4, 5, 6, 7], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid))]
    #[case(PrimitiveArray::new(
        buffer![0, 1, 2, 3, 4, 5],
        Validity::Array(BoolArray::from_iter([true, false, true, false, true, true]).into_array()),
    ))]
    #[case(PrimitiveArray::from_option_iter([Some(1), None, Some(3), Some(4), None]))]
    fn test_take_primitive_conformance(#[case] array: PrimitiveArray) {
        let array_ref = array.into_array();
        test_take_conformance(&array_ref);
    }
}
219
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::validity::Validity;

    /// A null index must yield a null output even when the raw index stored in that slot
    /// (here `3`, for a 3-element array) would be out of bounds.
    #[test]
    fn take_null_index_skips_out_of_bounds_value() {
        let source = PrimitiveArray::from_iter([10i32, 20, 30]);
        let index_validity = Validity::Array(BoolArray::from_iter([true, false]).into_array());
        let positions = PrimitiveArray::new(buffer![1u64, 3], index_validity);
        let mut ctx = array_session().create_execution_ctx();

        let taken = source.take(positions.into_array()).unwrap();

        assert_arrays_eq!(
            taken,
            PrimitiveArray::from_option_iter([Some(20i32), None]).into_array(),
            &mut ctx
        );
    }
}
249}