vortex_array/arrays/primitive/compute/take/
mod.rs1#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
5mod avx2;
6
7use std::sync::LazyLock;
8
9use vortex_buffer::Buffer;
10use vortex_buffer::BufferMut;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13
14use crate::ArrayRef;
15use crate::IntoArray;
16use crate::array::ArrayView;
17use crate::arrays::Primitive;
18use crate::arrays::PrimitiveArray;
19use crate::arrays::dict::TakeExecute;
20use crate::builtins::ArrayBuiltins;
21use crate::dtype::DType;
22use crate::dtype::IntegerPType;
23use crate::dtype::NativePType;
24use crate::executor::ExecutionCtx;
25use crate::match_each_integer_ptype;
26use crate::match_each_native_ptype;
27use crate::validity::Validity;
28
29static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| {
32 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
33 {
34 if is_x86_feature_detected!("avx2") {
35 &avx2::TakeKernelAVX2
36 } else {
37 &TakeKernelScalar
38 }
39 }
40
41 #[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
42 {
43 &TakeKernelScalar
44 }
45});
46
47trait TakeImpl: Send + Sync {
48 fn take(
49 &self,
50 array: ArrayView<'_, Primitive>,
51 indices: ArrayView<'_, Primitive>,
52 validity: Validity,
53 ) -> VortexResult<ArrayRef>;
54}
55
56struct TakeKernelScalar;
57
58impl TakeImpl for TakeKernelScalar {
59 fn take(
60 &self,
61 array: ArrayView<'_, Primitive>,
62 indices: ArrayView<'_, Primitive>,
63 validity: Validity,
64 ) -> VortexResult<ArrayRef> {
65 match_each_native_ptype!(array.ptype(), |T| {
66 match_each_integer_ptype!(indices.ptype(), |I| {
67 let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
68 Ok(PrimitiveArray::new(values, validity).into_array())
69 })
70 })
71 }
72}
73
74impl TakeExecute for Primitive {
75 fn take(
76 array: ArrayView<'_, Primitive>,
77 indices: &ArrayRef,
78 ctx: &mut ExecutionCtx,
79 ) -> VortexResult<Option<ArrayRef>> {
80 let DType::Primitive(ptype, null) = indices.dtype() else {
81 vortex_bail!("Invalid indices dtype: {}", indices.dtype())
82 };
83
84 let unsigned_indices = if ptype.is_unsigned_int() {
85 indices.clone().execute::<PrimitiveArray>(ctx)?
86 } else {
87 indices
89 .clone()
90 .cast(DType::Primitive(ptype.to_unsigned(), *null))?
91 .execute::<PrimitiveArray>(ctx)?
92 };
93
94 let validity = array
95 .validity()?
96 .take(&unsigned_indices.clone().into_array())?;
97 {
99 let unsigned_indices = unsigned_indices.as_view();
100 PRIMITIVE_TAKE_KERNEL
101 .take(array, unsigned_indices, validity)
102 .map(Some)
103 }
104 }
105}
106
107#[inline(always)]
109fn take_primitive_scalar<T: NativePType, I: IntegerPType>(
110 buffer: &[T],
111 indices: &[I],
112) -> Buffer<T> {
113 let mut result = BufferMut::with_capacity(indices.len());
117 let ptr = result.spare_capacity_mut().as_mut_ptr().cast::<T>();
118
119 for (i, idx) in indices.iter().enumerate() {
122 unsafe { ptr.add(i).write(buffer[idx.as_()]) };
124 }
125
126 unsafe { result.set_len(indices.len()) };
128 result.freeze()
129}
130
// These tests exercise only the portable scalar kernel and the architecture-independent
// conformance suite, so they run on every target (not only x86/x86_64).
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::primitive::compute::take::take_primitive_scalar;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive_scalar(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    #[test]
    fn test_take_empty_indices() {
        let a = [1i32, 2, 3];
        let result = take_primitive_scalar::<i32, u32>(&a, &[]);
        assert!(result.as_slice().is_empty());
    }

    #[test]
    fn test_take_narrow_index_type() {
        let a = [10u64, 20, 30];
        let result = take_primitive_scalar(&a, &[2u8, 0, 1]);
        assert_eq!(result.as_slice(), &[30u64, 10, 20]);
    }

    /// Out-of-range indices must hit the slice bounds check rather than read past the buffer.
    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_take_out_of_bounds_panics() {
        let a = [1i32, 2, 3];
        let _ = take_primitive_scalar(&a, &[3u32]);
    }

    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();

        // Position 0 reads a valid value; position 1 reads a null value; position 2 has a null
        // index. Both of the latter must come out null.
        let expected = [
            Scalar::from(Some(1)),
            Scalar::null_native::<i32>(),
            Scalar::null_native::<i32>(),
        ];
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        for (position, expected) in expected.into_iter().enumerate() {
            assert_eq!(
                actual
                    .execute_scalar(position, &mut ctx)
                    .vortex_expect("no fail"),
                expected,
                "mismatch at position {position}"
            );
        }
    }

    #[rstest]
    #[case(PrimitiveArray::new(buffer![42i32], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4, 5, 6, 7], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid))]
    #[case(PrimitiveArray::new(
        buffer![0, 1, 2, 3, 4, 5],
        Validity::Array(BoolArray::from_iter([true, false, true, false, true, true]).into_array()),
    ))]
    #[case(PrimitiveArray::from_option_iter([Some(1), None, Some(3), Some(4), None]))]
    fn test_take_primitive_conformance(#[case] array: PrimitiveArray) {
        test_take_conformance(&array.into_array());
    }
}