// vortex_array/arrays/primitive/compute/take/mod.rs

1#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
2mod avx2;
3
4#[cfg(feature = "nightly")]
5mod portable;
6
7use std::sync::LazyLock;
8
9use num_traits::AsPrimitive;
10use vortex_buffer::Buffer;
11use vortex_dtype::{DType, NativePType, match_each_integer_ptype, match_each_native_ptype};
12use vortex_error::{VortexResult, vortex_bail};
13
14use crate::arrays::PrimitiveVTable;
15use crate::arrays::primitive::PrimitiveArray;
16use crate::compute::{TakeKernel, TakeKernelAdapter, cast};
17use crate::validity::Validity;
18use crate::vtable::ValidityHelper;
19use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
20
21// Kernel selection happens on the first call to `take` and uses a combination of compile-time
22// and runtime feature detection to infer the best kernel for the platform.
23static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| {
24    cfg_if::cfg_if! {
25        if #[cfg(feature = "nightly")] {
26            // nightly codepath: use portable_simd kernel
27            &portable::TakeKernelPortableSimd
28        } else if #[cfg(target_arch = "x86_64")] {
29            // stable x86_64 path: use the optimized AVX2 kernel when available, falling
30            // back to scalar when not.
31            if is_x86_feature_detected!("avx2") {
32                &avx2::TakeKernelAVX2
33            } else {
34                &TakeKernelScalar
35            }
36        } else {
37            // stable all other platforms: scalar kernel
38            &TakeKernelScalar
39        }
40    }
41});
42
43trait TakeImpl: Send + Sync {
44    fn take(
45        &self,
46        array: &PrimitiveArray,
47        indices: &PrimitiveArray,
48        validity: Validity,
49    ) -> VortexResult<ArrayRef>;
50}
51
52#[allow(unused)]
53struct TakeKernelScalar;
54
55impl TakeImpl for TakeKernelScalar {
56    fn take(
57        &self,
58        array: &PrimitiveArray,
59        indices: &PrimitiveArray,
60        validity: Validity,
61    ) -> VortexResult<ArrayRef> {
62        match_each_native_ptype!(array.ptype(), |T| {
63            match_each_integer_ptype!(indices.ptype(), |I| {
64                let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
65                Ok(PrimitiveArray::new(values, validity).into_array())
66            })
67        })
68    }
69}
70
71impl TakeKernel for PrimitiveVTable {
72    fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
73        let unsigned_indices = match indices.dtype() {
74            DType::Primitive(p, n) => {
75                if p.is_unsigned_int() {
76                    indices.to_primitive()?
77                } else {
78                    // This will fail if all values cannot be converted to unsigned
79                    cast(indices, &DType::Primitive(p.to_unsigned(), *n))?.to_primitive()?
80                }
81            }
82            _ => vortex_bail!("Invalid indices dtype: {}", indices.dtype()),
83        };
84        let validity = array.validity().take(unsigned_indices.as_ref())?;
85        // Delegate to the best kernel based on the target CPU
86        PRIMITIVE_TAKE_KERNEL.take(array, &unsigned_indices, validity)
87    }
88}
89
90register_kernel!(TakeKernelAdapter(PrimitiveVTable).lift());
91
92// Compiler may see this as unused based on enabled features
93#[allow(unused)]
94#[inline(always)]
95fn take_primitive_scalar<T: NativePType, I: NativePType + AsPrimitive<usize>>(
96    array: &[T],
97    indices: &[I],
98) -> Buffer<T> {
99    indices.iter().map(|idx| array[idx.as_()]).collect()
100}
101
// These tests cover only the scalar gather helper and the generic `take` entry point, neither
// of which is architecture-specific, so the module is built for tests on every target rather
// than being restricted to x86/x86_64.
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::arrays::primitive::compute::take::take_primitive_scalar;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray};

    /// The scalar helper returns the values in index order, including repeated indices.
    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive_scalar(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    /// A result element is null when either the selected value or the index itself is null.
    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.scalar_at(0).unwrap(), Scalar::from(Some(1)));
        // position 3 is null
        assert_eq!(actual.scalar_at(1).unwrap(), Scalar::null_typed::<i32>());
        // the third index is null
        assert_eq!(actual.scalar_at(2).unwrap(), Scalar::null_typed::<i32>());
    }
}