Skip to main content

vortex_array/arrays/primitive/compute/take/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
5mod avx2;
6
7use std::sync::LazyLock;
8
9use vortex_buffer::Buffer;
10use vortex_buffer::BufferMut;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_mask::Mask;
14
15use crate::ArrayRef;
16use crate::IntoArray;
17use crate::array::ArrayView;
18use crate::arrays::ConstantArray;
19use crate::arrays::Primitive;
20use crate::arrays::PrimitiveArray;
21use crate::arrays::dict::TakeExecute;
22use crate::builtins::ArrayBuiltins;
23use crate::dtype::DType;
24use crate::dtype::IntegerPType;
25use crate::executor::ExecutionCtx;
26use crate::match_each_integer_ptype;
27use crate::match_each_native_ptype;
28use crate::scalar::Scalar;
29use crate::validity::Validity;
30
31// Kernel selection happens on the first call to `take` and uses a combination of compile-time
32// and runtime feature detection to infer the best kernel for the platform.
33static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| {
34    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
35    {
36        if is_x86_feature_detected!("avx2") {
37            &avx2::TakeKernelAVX2
38        } else {
39            &TakeKernelScalar
40        }
41    }
42
43    #[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
44    {
45        &TakeKernelScalar
46    }
47});
48
49trait TakeImpl: Send + Sync {
50    fn take(
51        &self,
52        array: ArrayView<'_, Primitive>,
53        indices: ArrayView<'_, Primitive>,
54        validity: Validity,
55    ) -> VortexResult<ArrayRef>;
56}
57
58struct TakeKernelScalar;
59
60impl TakeImpl for TakeKernelScalar {
61    fn take(
62        &self,
63        array: ArrayView<'_, Primitive>,
64        indices: ArrayView<'_, Primitive>,
65        validity: Validity,
66    ) -> VortexResult<ArrayRef> {
67        match_each_native_ptype!(array.ptype(), |T| {
68            match_each_integer_ptype!(indices.ptype(), |I| {
69                let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
70                Ok(PrimitiveArray::new(values, validity).into_array())
71            })
72        })
73    }
74}
75
76impl TakeExecute for Primitive {
77    fn take(
78        array: ArrayView<'_, Primitive>,
79        indices: &ArrayRef,
80        ctx: &mut ExecutionCtx,
81    ) -> VortexResult<Option<ArrayRef>> {
82        let DType::Primitive(ptype, null) = indices.dtype() else {
83            vortex_bail!("Invalid indices dtype: {}", indices.dtype())
84        };
85
86        let indices_validity = indices.validity()?;
87        // Null index lanes are semantically ignored, but their physical values may be out of
88        // bounds. Redirect those lanes to zero for the cast/gather, then restore the original index
89        // validity below.
90        let indices_nulls_zeroed = match indices_validity.execute_mask(indices.len(), ctx)? {
91            Mask::AllTrue(_) => indices.clone(),
92            Mask::AllFalse(_) => {
93                return Ok(Some(
94                    ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
95                        .into_array(),
96                ));
97            }
98            Mask::Values(_) => indices
99                .clone()
100                .fill_null(Scalar::from(0).cast(indices.dtype())?)?,
101        };
102
103        let unsigned_indices = if ptype.is_unsigned_int() {
104            indices_nulls_zeroed.execute::<PrimitiveArray>(ctx)?
105        } else {
106            // This will fail if all values cannot be converted to unsigned
107            indices_nulls_zeroed
108                .cast(DType::Primitive(ptype.to_unsigned(), *null))?
109                .execute::<PrimitiveArray>(ctx)?
110        };
111
112        let validity = array
113            .validity()?
114            .take(&unsigned_indices.clone().into_array())?
115            .and(indices_validity)?;
116        // Delegate to the best kernel based on the target CPU
117        {
118            let unsigned_indices = unsigned_indices.as_view();
119            PRIMITIVE_TAKE_KERNEL
120                .take(array, unsigned_indices, validity)
121                .map(Some)
122        }
123    }
124}
125
126// Compiler may see this as unused based on enabled features
127#[inline(always)]
128fn take_primitive_scalar<T: Copy, I: IntegerPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {
129    // NB: The simpler `indices.iter().map(|idx| buffer[idx.as_()]).collect()` generates suboptimal
130    // assembly where the buffer length is repeatedly loaded from the stack on each iteration.
131
132    let mut result = BufferMut::with_capacity(indices.len());
133    let ptr = result.spare_capacity_mut().as_mut_ptr().cast::<T>();
134
135    // This explicit loop with pointer writes keeps the length in a register and avoids per-element
136    // capacity checks from `push()`.
137    for (i, idx) in indices.iter().enumerate() {
138        // SAFETY: We reserved `indices.len()` capacity, so `ptr.add(i)` is valid.
139        unsafe { ptr.add(i).write(buffer[idx.as_()]) };
140    }
141
142    // SAFETY: We just wrote exactly `indices.len()` elements.
143    unsafe { result.set_len(indices.len()) };
144    result.freeze()
145}
146
// NOTE(review): nothing in this module looks x86-specific — confirm whether the architecture
// gate below is still needed.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::primitive::compute::take::take_primitive_scalar;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// The scalar gather returns values in index order, including repeated indices.
    #[test]
    fn test_take() {
        let source = vec![1i32, 2, 3, 4, 5];
        let gathered = take_primitive_scalar(&source, &[0, 0, 4, 2]);
        assert_eq!(gathered.as_slice(), &[1i32, 1, 5, 3]);
    }

    /// A result lane is null when either the index or the referenced source element is null.
    #[test]
    fn test_take_with_null_indices() {
        let source_validity =
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array());
        let values = PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5], source_validity);

        let index_validity = Validity::Array(BoolArray::from_iter([true, true, false]).into_array());
        let indices = PrimitiveArray::new(buffer![0, 3, 4], index_validity);

        let actual = values.take(indices.into_array()).unwrap();
        // Each lookup runs against a freshly created execution context.
        let scalar_at = |position| {
            actual
                .execute_scalar(position, &mut array_session().create_execution_ctx())
                .vortex_expect("no fail")
        };

        assert_eq!(scalar_at(0), Scalar::from(Some(1)));
        // Source position 3 is null.
        assert_eq!(scalar_at(1), Scalar::null_native::<i32>());
        // The third index itself is null.
        assert_eq!(scalar_at(2), Scalar::null_native::<i32>());
    }

    /// Runs the shared take conformance suite over a range of lengths and validity layouts.
    #[rstest]
    #[case(PrimitiveArray::new(buffer![42i32], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4, 5, 6, 7], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid))]
    #[case(PrimitiveArray::new(
        buffer![0, 1, 2, 3, 4, 5],
        Validity::Array(BoolArray::from_iter([true, false, true, false, true, true]).into_array()),
    ))]
    #[case(PrimitiveArray::from_option_iter([Some(1), None, Some(3), Some(4), None]))]
    fn test_take_primitive_conformance(#[case] array: PrimitiveArray) {
        test_take_conformance(&array.into_array());
    }
}
219
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::validity::Validity;

    /// A null index lane whose stored value (3) is out of bounds for a three-element source
    /// must produce a null result rather than a failure.
    #[test]
    fn take_null_index_skips_out_of_bounds_value() {
        let values = PrimitiveArray::from_iter([10i32, 20, 30]);

        // Second lane is null; its physical value 3 is past the end of `values`.
        let index_validity = Validity::Array(BoolArray::from_iter([true, false]).into_array());
        let indices = PrimitiveArray::new(buffer![1u64, 3], index_validity);

        let taken = values.take(indices.into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([Some(20i32), None]).into_array();

        let mut ctx = array_session().create_execution_ctx();
        assert_arrays_eq!(taken, expected, &mut ctx);
    }
}
249}