vortex_array/arrays/primitive/compute/take/
mod.rs1#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
5mod avx2;
6
7#[cfg(vortex_nightly)]
8mod portable;
9
10use std::sync::LazyLock;
11
12use vortex_buffer::Buffer;
13use vortex_buffer::BufferMut;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16
17use crate::ArrayRef;
18use crate::DynArray;
19use crate::IntoArray;
20use crate::arrays::PrimitiveArray;
21use crate::arrays::PrimitiveVTable;
22use crate::arrays::dict::TakeExecute;
23use crate::builtins::ArrayBuiltins;
24use crate::dtype::DType;
25use crate::dtype::IntegerPType;
26use crate::dtype::NativePType;
27use crate::executor::ExecutionCtx;
28use crate::match_each_integer_ptype;
29use crate::match_each_native_ptype;
30use crate::validity::Validity;
31use crate::vtable::ValidityHelper;
32
33static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| {
36 cfg_if::cfg_if! {
37 if #[cfg(vortex_nightly)] {
38 &portable::TakeKernelPortableSimd
40 } else if #[cfg(target_arch = "x86_64")] {
41 if is_x86_feature_detected!("avx2") {
44 &avx2::TakeKernelAVX2
45 } else {
46 &TakeKernelScalar
47 }
48 } else {
49 &TakeKernelScalar
51 }
52 }
53});
54
55trait TakeImpl: Send + Sync {
56 fn take(
57 &self,
58 array: &PrimitiveArray,
59 indices: &PrimitiveArray,
60 validity: Validity,
61 ) -> VortexResult<ArrayRef>;
62}
63
64#[allow(unused)]
65struct TakeKernelScalar;
66
67impl TakeImpl for TakeKernelScalar {
68 fn take(
69 &self,
70 array: &PrimitiveArray,
71 indices: &PrimitiveArray,
72 validity: Validity,
73 ) -> VortexResult<ArrayRef> {
74 match_each_native_ptype!(array.ptype(), |T| {
75 match_each_integer_ptype!(indices.ptype(), |I| {
76 let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
77 Ok(PrimitiveArray::new(values, validity).into_array())
78 })
79 })
80 }
81}
82
83impl TakeExecute for PrimitiveVTable {
84 fn take(
85 array: &PrimitiveArray,
86 indices: &ArrayRef,
87 ctx: &mut ExecutionCtx,
88 ) -> VortexResult<Option<ArrayRef>> {
89 let DType::Primitive(ptype, null) = indices.dtype() else {
90 vortex_bail!("Invalid indices dtype: {}", indices.dtype())
91 };
92
93 let unsigned_indices = if ptype.is_unsigned_int() {
94 indices.to_array().execute::<PrimitiveArray>(ctx)?
95 } else {
96 indices
98 .to_array()
99 .cast(DType::Primitive(ptype.to_unsigned(), *null))?
100 .execute::<PrimitiveArray>(ctx)?
101 };
102
103 let validity = array
104 .validity()
105 .take(&unsigned_indices.clone().into_array())?;
106 PRIMITIVE_TAKE_KERNEL
108 .take(array, &unsigned_indices, validity)
109 .map(Some)
110 }
111}
112
113#[allow(unused)]
115#[inline(always)]
116fn take_primitive_scalar<T: NativePType, I: IntegerPType>(
117 buffer: &[T],
118 indices: &[I],
119) -> Buffer<T> {
120 let mut result = BufferMut::with_capacity(indices.len());
124 let ptr = result.spare_capacity_mut().as_mut_ptr().cast::<T>();
125
126 for (i, idx) in indices.iter().enumerate() {
129 unsafe { ptr.add(i).write(buffer[idx.as_()]) };
131 }
132
133 unsafe { result.set_len(indices.len()) };
135 result.freeze()
136}
137
// Nothing here is architecture-specific (the scalar helper and the public `take` entry point),
// so the tests run on every target rather than only on x86/x86_64.
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::DynArray;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::primitive::compute::take::take_primitive_scalar;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive_scalar(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();
        assert_eq!(
            actual.scalar_at(0).vortex_expect("no fail"),
            Scalar::from(Some(1))
        );
        // Index 1 points at a null value.
        assert_eq!(
            actual.scalar_at(1).vortex_expect("no fail"),
            Scalar::null_native::<i32>()
        );
        // Index 2 is itself null.
        assert_eq!(
            actual.scalar_at(2).vortex_expect("no fail"),
            Scalar::null_native::<i32>()
        );
    }

    #[test]
    fn test_take_with_signed_indices() {
        // Signed indices should take the cast-to-unsigned path before the kernel runs.
        let values = PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30)]);
        let indices = PrimitiveArray::new(buffer![2i64, 0, 1, 2], Validity::NonNullable);
        let actual = values.take(indices.into_array()).unwrap();
        assert_eq!(
            actual.scalar_at(0).vortex_expect("no fail"),
            Scalar::from(Some(30))
        );
        assert_eq!(
            actual.scalar_at(1).vortex_expect("no fail"),
            Scalar::from(Some(10))
        );
        assert_eq!(
            actual.scalar_at(2).vortex_expect("no fail"),
            Scalar::from(Some(20))
        );
        assert_eq!(
            actual.scalar_at(3).vortex_expect("no fail"),
            Scalar::from(Some(30))
        );
    }

    #[rstest]
    #[case(PrimitiveArray::new(buffer![42i32], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4, 5, 6, 7], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid))]
    #[case(PrimitiveArray::new(
        buffer![0, 1, 2, 3, 4, 5],
        Validity::Array(BoolArray::from_iter([true, false, true, false, true, true]).into_array()),
    ))]
    #[case(PrimitiveArray::from_option_iter([Some(1), None, Some(3), Some(4), None]))]
    fn test_take_primitive_conformance(#[case] array: PrimitiveArray) {
        test_take_conformance(&array.into_array());
    }
}