vortex_array/arrays/primitive/compute/take/
mod.rs1#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
5mod avx2;
6
7#[cfg(vortex_nightly)]
8mod portable;
9
10use std::sync::LazyLock;
11
12use vortex_buffer::Buffer;
13use vortex_buffer::BufferMut;
14use vortex_dtype::DType;
15use vortex_dtype::IntegerPType;
16use vortex_dtype::NativePType;
17use vortex_dtype::match_each_integer_ptype;
18use vortex_dtype::match_each_native_ptype;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21
22use crate::Array;
23use crate::ArrayRef;
24use crate::IntoArray;
25use crate::ToCanonical;
26use crate::arrays::PrimitiveVTable;
27use crate::arrays::TakeExecute;
28use crate::arrays::primitive::PrimitiveArray;
29use crate::builtins::ArrayBuiltins;
30use crate::executor::ExecutionCtx;
31use crate::validity::Validity;
32use crate::vtable::ValidityHelper;
33
34static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| {
37 cfg_if::cfg_if! {
38 if #[cfg(vortex_nightly)] {
39 &portable::TakeKernelPortableSimd
41 } else if #[cfg(target_arch = "x86_64")] {
42 if is_x86_feature_detected!("avx2") {
45 &avx2::TakeKernelAVX2
46 } else {
47 &TakeKernelScalar
48 }
49 } else {
50 &TakeKernelScalar
52 }
53 }
54});
55
56trait TakeImpl: Send + Sync {
57 fn take(
58 &self,
59 array: &PrimitiveArray,
60 indices: &PrimitiveArray,
61 validity: Validity,
62 ) -> VortexResult<ArrayRef>;
63}
64
/// Plain, non-SIMD `take` kernel that copies one element per index.
///
/// It is the fallback picked by `PRIMITIVE_TAKE_KERNEL` when no SIMD kernel applies.
// Builds with `cfg(vortex_nightly)` always select the portable SIMD kernel, which leaves this
// type unreferenced there; the lint is silenced for that configuration.
#[allow(unused)]
struct TakeKernelScalar;
67
68impl TakeImpl for TakeKernelScalar {
69 fn take(
70 &self,
71 array: &PrimitiveArray,
72 indices: &PrimitiveArray,
73 validity: Validity,
74 ) -> VortexResult<ArrayRef> {
75 match_each_native_ptype!(array.ptype(), |T| {
76 match_each_integer_ptype!(indices.ptype(), |I| {
77 let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
78 Ok(PrimitiveArray::new(values, validity).into_array())
79 })
80 })
81 }
82}
83
84impl TakeExecute for PrimitiveVTable {
85 fn take(
86 array: &PrimitiveArray,
87 indices: &dyn Array,
88 _ctx: &mut ExecutionCtx,
89 ) -> VortexResult<Option<ArrayRef>> {
90 let DType::Primitive(ptype, null) = indices.dtype() else {
91 vortex_bail!("Invalid indices dtype: {}", indices.dtype())
92 };
93
94 let unsigned_indices = if ptype.is_unsigned_int() {
95 indices.to_primitive()
96 } else {
97 indices
99 .to_array()
100 .cast(DType::Primitive(ptype.to_unsigned(), *null))?
101 .to_primitive()
102 };
103
104 let validity = array.validity().take(unsigned_indices.as_ref())?;
105 PRIMITIVE_TAKE_KERNEL
107 .take(array, &unsigned_indices, validity)
108 .map(Some)
109 }
110}
111
112#[allow(unused)]
114#[inline(always)]
115fn take_primitive_scalar<T: NativePType, I: IntegerPType>(
116 buffer: &[T],
117 indices: &[I],
118) -> Buffer<T> {
119 let mut result = BufferMut::with_capacity(indices.len());
123 let ptr = result.spare_capacity_mut().as_mut_ptr().cast::<T>();
124
125 for (i, idx) in indices.iter().enumerate() {
128 unsafe { ptr.add(i).write(buffer[idx.as_()]) };
130 }
131
132 unsafe { result.set_len(indices.len()) };
134 result.freeze()
135}
136
// These tests exercise only platform-independent code (the scalar helper and the public
// `take` path), so they are built for every target. Gating them on x86 would leave the scalar
// fallback untested on exactly the targets where it is the kernel in use.
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::primitive::compute::take::take_primitive_scalar;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// The scalar helper returns values in index order and allows repeated indices.
    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive_scalar(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    /// A result slot is null when either the selected value or the index itself is null.
    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = values.take(indices.to_array()).unwrap();
        // Index 0 is valid and points at a valid value.
        assert_eq!(
            actual.scalar_at(0).vortex_expect("no fail"),
            Scalar::from(Some(1))
        );
        // Index 3 is valid but points at a null value.
        assert_eq!(
            actual.scalar_at(1).vortex_expect("no fail"),
            Scalar::null_native::<i32>()
        );
        // The third index is itself null.
        assert_eq!(
            actual.scalar_at(2).vortex_expect("no fail"),
            Scalar::null_native::<i32>()
        );
    }

    /// Runs the shared `take` conformance suite over arrays of several lengths and validities.
    #[rstest]
    #[case(PrimitiveArray::new(buffer![42i32], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4, 5, 6, 7], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid))]
    #[case(PrimitiveArray::new(
        buffer![0, 1, 2, 3, 4, 5],
        Validity::Array(BoolArray::from_iter([true, false, true, false, true, true]).into_array()),
    ))]
    #[case(PrimitiveArray::from_option_iter([Some(1), None, Some(3), Some(4), None]))]
    fn test_take_primitive_conformance(#[case] array: PrimitiveArray) {
        test_take_conformance(array.as_ref());
    }
}