vortex_array/arrays/primitive/compute/take/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
5mod avx2;
6
7#[cfg(vortex_nightly)]
8mod portable;
9
10use std::sync::LazyLock;
11
12use num_traits::AsPrimitive;
13use vortex_buffer::Buffer;
14use vortex_dtype::{DType, NativePType, match_each_integer_ptype, match_each_native_ptype};
15use vortex_error::{VortexResult, vortex_bail};
16
17use crate::arrays::PrimitiveVTable;
18use crate::arrays::primitive::PrimitiveArray;
19use crate::compute::{TakeKernel, TakeKernelAdapter, cast};
20use crate::validity::Validity;
21use crate::vtable::ValidityHelper;
22use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
23
24// Kernel selection happens on the first call to `take` and uses a combination of compile-time
25// and runtime feature detection to infer the best kernel for the platform.
26static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| {
27    cfg_if::cfg_if! {
28        if #[cfg(vortex_nightly)] {
29            // nightly codepath: use portable_simd kernel
30            &portable::TakeKernelPortableSimd
31        } else if #[cfg(target_arch = "x86_64")] {
32            // stable x86_64 path: use the optimized AVX2 kernel when available, falling
33            // back to scalar when not.
34            if is_x86_feature_detected!("avx2") {
35                &avx2::TakeKernelAVX2
36            } else {
37                &TakeKernelScalar
38            }
39        } else {
40            // stable all other platforms: scalar kernel
41            &TakeKernelScalar
42        }
43    }
44});
45
46trait TakeImpl: Send + Sync {
47    fn take(
48        &self,
49        array: &PrimitiveArray,
50        indices: &PrimitiveArray,
51        validity: Validity,
52    ) -> VortexResult<ArrayRef>;
53}
54
55#[allow(unused)]
56struct TakeKernelScalar;
57
58impl TakeImpl for TakeKernelScalar {
59    fn take(
60        &self,
61        array: &PrimitiveArray,
62        indices: &PrimitiveArray,
63        validity: Validity,
64    ) -> VortexResult<ArrayRef> {
65        match_each_native_ptype!(array.ptype(), |T| {
66            match_each_integer_ptype!(indices.ptype(), |I| {
67                let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
68                Ok(PrimitiveArray::new(values, validity).into_array())
69            })
70        })
71    }
72}
73
74impl TakeKernel for PrimitiveVTable {
75    fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
76        let DType::Primitive(ptype, null) = indices.dtype() else {
77            vortex_bail!("Invalid indices dtype: {}", indices.dtype())
78        };
79
80        let unsigned_indices = if ptype.is_unsigned_int() {
81            indices.to_primitive()?
82        } else {
83            // This will fail if all values cannot be converted to unsigned
84            cast(indices, &DType::Primitive(ptype.to_unsigned(), *null))?.to_primitive()?
85        };
86
87        let validity = array.validity().take(unsigned_indices.as_ref())?;
88        // Delegate to the best kernel based on the target CPU
89        PRIMITIVE_TAKE_KERNEL.take(array, &unsigned_indices, validity)
90    }
91}
92
93register_kernel!(TakeKernelAdapter(PrimitiveVTable).lift());
94
95// Compiler may see this as unused based on enabled features
96#[allow(unused)]
97#[inline(always)]
98fn take_primitive_scalar<T: NativePType, I: NativePType + AsPrimitive<usize>>(
99    array: &[T],
100    indices: &[I],
101) -> Buffer<T> {
102    indices.iter().map(|idx| array[idx.as_()]).collect()
103}
104
// These tests exercise the scalar gather helper and the generic `take` entry point. Neither is
// architecture-specific, so the module is compiled on every target; on non-x86 targets the
// scalar kernel is the one actually used in production.
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::arrays::primitive::compute::take::take_primitive_scalar;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray, ToCanonical};

    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive_scalar(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    #[test]
    fn test_take_empty_indices() {
        let a = vec![1i32, 2, 3];
        let result = take_primitive_scalar::<i32, u32>(&a, &[]);
        assert!(result.as_slice().is_empty());
    }

    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.scalar_at(0), Scalar::from(Some(1)));
        // position 3 is null
        assert_eq!(actual.scalar_at(1), Scalar::null_typed::<i32>());
        // the third index is null
        assert_eq!(actual.scalar_at(2), Scalar::null_typed::<i32>());
    }

    #[test]
    fn test_take_with_unsigned_indices() {
        // Unsigned indices skip the signed->unsigned cast and go straight to the kernel.
        let values = PrimitiveArray::new(buffer![10u16, 20, 30], Validity::NonNullable);
        let indices = PrimitiveArray::new(buffer![2u8, 0, 2], Validity::NonNullable);
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(
            actual.to_primitive().unwrap().as_slice::<u16>(),
            &[30, 10, 30]
        );
    }

    #[rstest]
    #[case(PrimitiveArray::new(buffer![42i32], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4, 5, 6, 7], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid))]
    #[case(PrimitiveArray::new(
        buffer![0, 1, 2, 3, 4, 5],
        Validity::Array(BoolArray::from_iter([true, false, true, false, true, true]).into_array()),
    ))]
    #[case(PrimitiveArray::from_option_iter([Some(1), None, Some(3), Some(4), None]))]
    fn test_take_primitive_conformance(#[case] array: PrimitiveArray) {
        test_take_conformance(array.as_ref());
    }
}
158}