vortex_array/arrays/primitive/compute/take/
mod.rs1#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
5mod avx2;
6
7#[cfg(vortex_nightly)]
8mod portable;
9
10use std::sync::LazyLock;
11
12use vortex_buffer::Buffer;
13use vortex_buffer::BufferMut;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16
17use crate::ArrayRef;
18use crate::IntoArray;
19use crate::array::ArrayView;
20use crate::arrays::Primitive;
21use crate::arrays::PrimitiveArray;
22use crate::arrays::dict::TakeExecute;
23use crate::builtins::ArrayBuiltins;
24use crate::dtype::DType;
25use crate::dtype::IntegerPType;
26use crate::dtype::NativePType;
27use crate::executor::ExecutionCtx;
28use crate::match_each_integer_ptype;
29use crate::match_each_native_ptype;
30use crate::validity::Validity;
31
32static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| {
35 cfg_if::cfg_if! {
36 if #[cfg(vortex_nightly)] {
37 &portable::TakeKernelPortableSimd
39 } else if #[cfg(target_arch = "x86_64")] {
40 if is_x86_feature_detected!("avx2") {
43 &avx2::TakeKernelAVX2
44 } else {
45 &TakeKernelScalar
46 }
47 } else {
48 &TakeKernelScalar
50 }
51 }
52});
53
54trait TakeImpl: Send + Sync {
55 fn take(
56 &self,
57 array: ArrayView<'_, Primitive>,
58 indices: ArrayView<'_, Primitive>,
59 validity: Validity,
60 ) -> VortexResult<ArrayRef>;
61}
62
63#[allow(unused)]
64struct TakeKernelScalar;
65
66impl TakeImpl for TakeKernelScalar {
67 fn take(
68 &self,
69 array: ArrayView<'_, Primitive>,
70 indices: ArrayView<'_, Primitive>,
71 validity: Validity,
72 ) -> VortexResult<ArrayRef> {
73 match_each_native_ptype!(array.ptype(), |T| {
74 match_each_integer_ptype!(indices.ptype(), |I| {
75 let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
76 Ok(PrimitiveArray::new(values, validity).into_array())
77 })
78 })
79 }
80}
81
82impl TakeExecute for Primitive {
83 fn take(
84 array: ArrayView<'_, Primitive>,
85 indices: &ArrayRef,
86 ctx: &mut ExecutionCtx,
87 ) -> VortexResult<Option<ArrayRef>> {
88 let DType::Primitive(ptype, null) = indices.dtype() else {
89 vortex_bail!("Invalid indices dtype: {}", indices.dtype())
90 };
91
92 let unsigned_indices = if ptype.is_unsigned_int() {
93 indices.clone().execute::<PrimitiveArray>(ctx)?
94 } else {
95 indices
97 .clone()
98 .cast(DType::Primitive(ptype.to_unsigned(), *null))?
99 .execute::<PrimitiveArray>(ctx)?
100 };
101
102 let validity = array
103 .validity()?
104 .take(&unsigned_indices.clone().into_array())?;
105 {
107 let unsigned_indices = unsigned_indices.as_view();
108 PRIMITIVE_TAKE_KERNEL
109 .take(array, unsigned_indices, validity)
110 .map(Some)
111 }
112 }
113}
114
115#[allow(unused)]
117#[inline(always)]
118fn take_primitive_scalar<T: NativePType, I: IntegerPType>(
119 buffer: &[T],
120 indices: &[I],
121) -> Buffer<T> {
122 let mut result = BufferMut::with_capacity(indices.len());
126 let ptr = result.spare_capacity_mut().as_mut_ptr().cast::<T>();
127
128 for (i, idx) in indices.iter().enumerate() {
131 unsafe { ptr.add(i).write(buffer[idx.as_()]) };
133 }
134
135 unsafe { result.set_len(indices.len()) };
137 result.freeze()
138}
139
// These tests are architecture-independent, so they are compiled on every target rather than
// only on x86/x86_64. On targets where no SIMD kernel is selected, this also exercises the
// scalar fallback path end to end.
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::primitive::compute::take::take_primitive_scalar;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive_scalar(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    #[test]
    fn test_take_unsigned_and_empty_indices() {
        let a = vec![1i32, 2, 3, 4, 5];

        let result = take_primitive_scalar(&a, &[4u8, 1]);
        assert_eq!(result.as_slice(), &[5i32, 2]);

        let empty = take_primitive_scalar::<i32, u32>(&a, &[]);
        assert!(empty.as_slice().is_empty());
    }

    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();
        assert_eq!(
            actual.scalar_at(0).vortex_expect("no fail"),
            Scalar::from(Some(1))
        );
        assert_eq!(
            actual.scalar_at(1).vortex_expect("no fail"),
            Scalar::null_native::<i32>()
        );
        assert_eq!(
            actual.scalar_at(2).vortex_expect("no fail"),
            Scalar::null_native::<i32>()
        );
    }

    #[rstest]
    #[case(PrimitiveArray::new(buffer![42i32], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4, 5, 6, 7], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid))]
    #[case(PrimitiveArray::new(
        buffer![0, 1, 2, 3, 4, 5],
        Validity::Array(BoolArray::from_iter([true, false, true, false, true, true]).into_array()),
    ))]
    #[case(PrimitiveArray::from_option_iter([Some(1), None, Some(3), Some(4), None]))]
    fn test_take_primitive_conformance(#[case] array: PrimitiveArray) {
        test_take_conformance(&array.into_array());
    }
}