vortex_array/arrays/primitive/compute/take/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
5mod avx2;
6
7#[cfg(vortex_nightly)]
8mod portable;
9
10use std::sync::LazyLock;
11
12use vortex_buffer::Buffer;
13use vortex_dtype::{
14    DType, IntegerPType, NativePType, match_each_integer_ptype, match_each_native_ptype,
15};
16use vortex_error::{VortexResult, vortex_bail};
17
18use crate::arrays::PrimitiveVTable;
19use crate::arrays::primitive::PrimitiveArray;
20use crate::compute::{TakeKernel, TakeKernelAdapter, cast};
21use crate::validity::Validity;
22use crate::vtable::ValidityHelper;
23use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
24
25// Kernel selection happens on the first call to `take` and uses a combination of compile-time
26// and runtime feature detection to infer the best kernel for the platform.
27static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| {
28    cfg_if::cfg_if! {
29        if #[cfg(vortex_nightly)] {
30            // nightly codepath: use portable_simd kernel
31            &portable::TakeKernelPortableSimd
32        } else if #[cfg(target_arch = "x86_64")] {
33            // stable x86_64 path: use the optimized AVX2 kernel when available, falling
34            // back to scalar when not.
35            if is_x86_feature_detected!("avx2") {
36                &avx2::TakeKernelAVX2
37            } else {
38                &TakeKernelScalar
39            }
40        } else {
41            // stable all other platforms: scalar kernel
42            &TakeKernelScalar
43        }
44    }
45});
46
47trait TakeImpl: Send + Sync {
48    fn take(
49        &self,
50        array: &PrimitiveArray,
51        indices: &PrimitiveArray,
52        validity: Validity,
53    ) -> VortexResult<ArrayRef>;
54}
55
56#[allow(unused)]
57struct TakeKernelScalar;
58
59impl TakeImpl for TakeKernelScalar {
60    fn take(
61        &self,
62        array: &PrimitiveArray,
63        indices: &PrimitiveArray,
64        validity: Validity,
65    ) -> VortexResult<ArrayRef> {
66        match_each_native_ptype!(array.ptype(), |T| {
67            match_each_integer_ptype!(indices.ptype(), |I| {
68                let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
69                Ok(PrimitiveArray::new(values, validity).into_array())
70            })
71        })
72    }
73}
74
75impl TakeKernel for PrimitiveVTable {
76    fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
77        let DType::Primitive(ptype, null) = indices.dtype() else {
78            vortex_bail!("Invalid indices dtype: {}", indices.dtype())
79        };
80
81        let unsigned_indices = if ptype.is_unsigned_int() {
82            indices.to_primitive()
83        } else {
84            // This will fail if all values cannot be converted to unsigned
85            cast(indices, &DType::Primitive(ptype.to_unsigned(), *null))?.to_primitive()
86        };
87
88        let validity = array.validity().take(unsigned_indices.as_ref())?;
89        // Delegate to the best kernel based on the target CPU
90        PRIMITIVE_TAKE_KERNEL.take(array, &unsigned_indices, validity)
91    }
92}
93
94register_kernel!(TakeKernelAdapter(PrimitiveVTable).lift());
95
96// Compiler may see this as unused based on enabled features
97#[allow(unused)]
98#[inline(always)]
99fn take_primitive_scalar<T: NativePType, I: IntegerPType>(array: &[T], indices: &[I]) -> Buffer<T> {
100    indices.iter().map(|idx| array[idx.as_()]).collect()
101}
102
// These tests exercise the scalar kernel and the public `take` entry point. Nothing here is
// architecture-specific, so they run on every target (not just x86/x86_64).
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::arrays::primitive::compute::take::take_primitive_scalar;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray, ToCanonical};

    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive_scalar(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.scalar_at(0), Scalar::from(Some(1)));
        // position 3 is null
        assert_eq!(actual.scalar_at(1), Scalar::null_typed::<i32>());
        // the third index is null
        assert_eq!(actual.scalar_at(2), Scalar::null_typed::<i32>());
    }

    #[test]
    fn test_take_signed_indices() {
        // Signed indices go through the cast-to-unsigned path before reaching the kernel.
        let values = PrimitiveArray::new(buffer![10i32, 20, 30, 40, 50], Validity::NonNullable);
        let indices = PrimitiveArray::new(buffer![4i64, 0, 2, 2], Validity::NonNullable);
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.to_primitive().as_slice::<i32>(), &[50, 10, 30, 30]);
    }

    #[rstest]
    #[case(PrimitiveArray::new(buffer![42i32], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4, 5, 6, 7], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid))]
    #[case(PrimitiveArray::new(
        buffer![0, 1, 2, 3, 4, 5],
        Validity::Array(BoolArray::from_iter([true, false, true, false, true, true]).into_array()),
    ))]
    #[case(PrimitiveArray::from_option_iter([Some(1), None, Some(3), Some(4), None]))]
    fn test_take_primitive_conformance(#[case] array: PrimitiveArray) {
        test_take_conformance(array.as_ref());
    }
}