// vortex_array/arrays/primitive/compute/take/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
5mod avx2;
6
7#[cfg(vortex_nightly)]
8mod portable;
9
10use std::sync::LazyLock;
11
12use vortex_buffer::Buffer;
13use vortex_buffer::BufferMut;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16
17use crate::Array;
18use crate::ArrayRef;
19use crate::IntoArray;
20use crate::arrays::PrimitiveVTable;
21use crate::arrays::TakeExecute;
22use crate::arrays::primitive::PrimitiveArray;
23use crate::builtins::ArrayBuiltins;
24use crate::dtype::DType;
25use crate::dtype::IntegerPType;
26use crate::dtype::NativePType;
27use crate::executor::ExecutionCtx;
28use crate::match_each_integer_ptype;
29use crate::match_each_native_ptype;
30use crate::validity::Validity;
31use crate::vtable::ValidityHelper;
32
// Kernel selection happens on the first call to `take` and uses a combination of compile-time
// and runtime feature detection to infer the best kernel for the platform.
//
// `LazyLock` guarantees the runtime feature probe (`is_x86_feature_detected!`) runs at most
// once; every later call reuses the cached `&'static dyn TakeImpl`.
static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| {
    cfg_if::cfg_if! {
        if #[cfg(vortex_nightly)] {
            // nightly codepath: use portable_simd kernel
            &portable::TakeKernelPortableSimd
        } else if #[cfg(target_arch = "x86_64")] {
            // stable x86_64 path: use the optimized AVX2 kernel when available, falling
            // back to scalar when not.
            // NOTE(review): `mod avx2` is compiled for both x86_64 and 32-bit x86 (see the
            // cfg at the top of this file), but is only ever selected here on x86_64 —
            // confirm whether the 32-bit build intentionally falls through to scalar.
            if is_x86_feature_detected!("avx2") {
                &avx2::TakeKernelAVX2
            } else {
                &TakeKernelScalar
            }
        } else {
            // stable all other platforms: scalar kernel
            &TakeKernelScalar
        }
    }
});
54
/// Object-safe interface implemented by each `take` kernel variant
/// (scalar, AVX2, portable SIMD).
///
/// `Send + Sync` is required so that a kernel reference can be stored in the
/// process-wide `PRIMITIVE_TAKE_KERNEL` `LazyLock`.
trait TakeImpl: Send + Sync {
    /// Gather elements of `array` at the positions given by `indices`, attaching the
    /// precomputed `validity` to the resulting array.
    fn take(
        &self,
        array: &PrimitiveArray,
        indices: &PrimitiveArray,
        validity: Validity,
    ) -> VortexResult<ArrayRef>;
}
63
// Scalar (non-SIMD) fallback kernel, available on every platform. Depending on target
// and cfg flags this type may never be selected, hence the allow.
#[allow(unused)]
struct TakeKernelScalar;

impl TakeImpl for TakeKernelScalar {
    // Monomorphize over the concrete value type `T` and the integer index type `I`,
    // run the scalar gather loop, then rebuild a PrimitiveArray with the precomputed
    // validity.
    fn take(
        &self,
        array: &PrimitiveArray,
        indices: &PrimitiveArray,
        validity: Validity,
    ) -> VortexResult<ArrayRef> {
        match_each_native_ptype!(array.ptype(), |T| {
            match_each_integer_ptype!(indices.ptype(), |I| {
                let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
                Ok(PrimitiveArray::new(values, validity).into_array())
            })
        })
    }
}
82
83impl TakeExecute for PrimitiveVTable {
84    fn take(
85        array: &PrimitiveArray,
86        indices: &ArrayRef,
87        ctx: &mut ExecutionCtx,
88    ) -> VortexResult<Option<ArrayRef>> {
89        let DType::Primitive(ptype, null) = indices.dtype() else {
90            vortex_bail!("Invalid indices dtype: {}", indices.dtype())
91        };
92
93        let unsigned_indices = if ptype.is_unsigned_int() {
94            indices.to_array().execute::<PrimitiveArray>(ctx)?
95        } else {
96            // This will fail if all values cannot be converted to unsigned
97            indices
98                .to_array()
99                .cast(DType::Primitive(ptype.to_unsigned(), *null))?
100                .execute::<PrimitiveArray>(ctx)?
101        };
102
103        let validity = array.validity().take(&unsigned_indices.to_array())?;
104        // Delegate to the best kernel based on the target CPU
105        PRIMITIVE_TAKE_KERNEL
106            .take(array, &unsigned_indices, validity)
107            .map(Some)
108    }
109}
110
111// Compiler may see this as unused based on enabled features
112#[allow(unused)]
113#[inline(always)]
114fn take_primitive_scalar<T: NativePType, I: IntegerPType>(
115    buffer: &[T],
116    indices: &[I],
117) -> Buffer<T> {
118    // NB: The simpler `indices.iter().map(|idx| buffer[idx.as_()]).collect()` generates suboptimal
119    // assembly where the buffer length is repeatedly loaded from the stack on each iteration.
120
121    let mut result = BufferMut::with_capacity(indices.len());
122    let ptr = result.spare_capacity_mut().as_mut_ptr().cast::<T>();
123
124    // This explicit loop with pointer writes keeps the length in a register and avoids per-element
125    // capacity checks from `push()`.
126    for (i, idx) in indices.iter().enumerate() {
127        // SAFETY: We reserved `indices.len()` capacity, so `ptr.add(i)` is valid.
128        unsafe { ptr.add(i).write(buffer[idx.as_()]) };
129    }
130
131    // SAFETY: We just wrote exactly `indices.len()` elements.
132    unsafe { result.set_len(indices.len()) };
133    result.freeze()
134}
135
// NOTE(review): this module was previously gated on
// `#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]`, but nothing in it is
// x86-specific — it exercises the scalar gather loop and the generic `take` entry point.
// The gate silently skipped the entire suite on other targets (e.g. aarch64), so it has
// been removed.
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::primitive::compute::take::take_primitive_scalar;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    // Basic gather: duplicate and out-of-order indices are allowed.
    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive_scalar(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    // The output is null when either the gathered value is null or the index itself is null.
    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = values.take(indices.to_array()).unwrap();
        assert_eq!(
            actual.scalar_at(0).vortex_expect("no fail"),
            Scalar::from(Some(1))
        );
        // position 3 is null
        assert_eq!(
            actual.scalar_at(1).vortex_expect("no fail"),
            Scalar::null_native::<i32>()
        );
        // the third index is null
        assert_eq!(
            actual.scalar_at(2).vortex_expect("no fail"),
            Scalar::null_native::<i32>()
        );
    }

    // Conformance over a spread of lengths and validity encodings
    // (NonNullable, AllValid, and an explicit validity bitmap).
    #[rstest]
    #[case(PrimitiveArray::new(buffer![42i32], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4, 5, 6, 7], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid))]
    #[case(PrimitiveArray::new(
        buffer![0, 1, 2, 3, 4, 5],
        Validity::Array(BoolArray::from_iter([true, false, true, false, true, true]).into_array()),
    ))]
    #[case(PrimitiveArray::from_option_iter([Some(1), None, Some(3), Some(4), None]))]
    fn test_take_primitive_conformance(#[case] array: PrimitiveArray) {
        test_take_conformance(&array.to_array());
    }
}
200}