// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
// AVX2-accelerated take kernel. Compiled on x86/x86_64; actually selected at runtime only on
// x86_64 when the CPU reports AVX2 support (see `PRIMITIVE_TAKE_KERNEL` below).
// NOTE(review): this is also compiled for 32-bit x86, but the selection logic below never picks
// it there — confirm whether the `target_arch = "x86"` cfg is intentional.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
mod avx2;

// `portable_simd`-based kernel, available only under the nightly-only `vortex_nightly` cfg.
#[cfg(vortex_nightly)]
mod portable;
9
use std::sync::LazyLock;

// Vortex ecosystem crates: buffers, dtypes, and error handling.
use vortex_buffer::Buffer;
use vortex_buffer::BufferMut;
use vortex_dtype::DType;
use vortex_dtype::IntegerPType;
use vortex_dtype::NativePType;
use vortex_dtype::match_each_integer_ptype;
use vortex_dtype::match_each_native_ptype;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;

// Crate-local array machinery.
use crate::Array;
use crate::ArrayRef;
use crate::IntoArray;
use crate::ToCanonical;
use crate::arrays::PrimitiveVTable;
use crate::arrays::TakeExecute;
use crate::arrays::primitive::PrimitiveArray;
use crate::builtins::ArrayBuiltins;
use crate::executor::ExecutionCtx;
use crate::validity::Validity;
use crate::vtable::ValidityHelper;
33
// Kernel selection happens on the first call to `take` and uses a combination of compile-time
// (`cfg_if`) and runtime (`is_x86_feature_detected!`) feature detection to infer the best kernel
// for the platform. The result is cached for the lifetime of the process by `LazyLock`; each
// branch yields a `&'static` reference to a unit-struct kernel.
static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| {
    cfg_if::cfg_if! {
        if #[cfg(vortex_nightly)] {
            // nightly codepath: use portable_simd kernel
            &portable::TakeKernelPortableSimd
        } else if #[cfg(target_arch = "x86_64")] {
            // stable x86_64 path: use the optimized AVX2 kernel when available (checked at
            // runtime, so a single binary works on both AVX2 and pre-AVX2 CPUs), falling
            // back to scalar when not.
            if is_x86_feature_detected!("avx2") {
                &avx2::TakeKernelAVX2
            } else {
                &TakeKernelScalar
            }
        } else {
            // stable, all other platforms (including 32-bit x86): scalar kernel
            &TakeKernelScalar
        }
    }
});
55
/// Dispatch target for the primitive `take` kernels (scalar, AVX2, portable-SIMD).
///
/// Implementations gather `array[indices[i]]` for every index and return a new primitive array
/// carrying the supplied `validity`. Callers (see `TakeExecute::take` below) pass `indices`
/// already converted to an unsigned integer ptype. `Send + Sync` is required because the chosen
/// kernel is stored in a process-wide `static`.
trait TakeImpl: Send + Sync {
    fn take(
        &self,
        array: &PrimitiveArray,
        indices: &PrimitiveArray,
        validity: Validity,
    ) -> VortexResult<ArrayRef>;
}
64
/// Fallback take kernel with no SIMD requirements.
// `allow(unused)`: under some cfg combinations (e.g. `vortex_nightly`) the selection logic above
// never references this type.
#[allow(unused)]
struct TakeKernelScalar;
67
impl TakeImpl for TakeKernelScalar {
    /// Scalar (non-SIMD) gather: monomorphizes over the element ptype `T` and the index ptype
    /// `I` (unsigned by the time we get here — see `TakeExecute::take`), then copies
    /// `array[indices[i]]` element by element via `take_primitive_scalar`.
    fn take(
        &self,
        array: &PrimitiveArray,
        indices: &PrimitiveArray,
        validity: Validity,
    ) -> VortexResult<ArrayRef> {
        match_each_native_ptype!(array.ptype(), |T| {
            match_each_integer_ptype!(indices.ptype(), |I| {
                let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
                Ok(PrimitiveArray::new(values, validity).into_array())
            })
        })
    }
}
83
impl TakeExecute for PrimitiveVTable {
    /// `take` entry point for primitive arrays: normalizes the index array to an unsigned
    /// integer ptype, computes the result validity, and delegates the value gather to the
    /// lazily-selected platform kernel.
    fn take(
        array: &PrimitiveArray,
        indices: &dyn Array,
        _ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        // Indices must themselves be a primitive (integer) array.
        let DType::Primitive(ptype, null) = indices.dtype() else {
            vortex_bail!("Invalid indices dtype: {}", indices.dtype())
        };

        // Kernels operate on unsigned indices, so signed index arrays are cast first.
        let unsigned_indices = if ptype.is_unsigned_int() {
            indices.to_primitive()
        } else {
            // This will fail if any value cannot be converted to unsigned (i.e. is negative).
            indices
                .to_array()
                .cast(DType::Primitive(ptype.to_unsigned(), *null))?
                .to_primitive()
        };

        // Result validity combines the source array's validity with null indices.
        let validity = array.validity().take(unsigned_indices.as_ref())?;
        // Delegate to the best kernel based on the target CPU
        PRIMITIVE_TAKE_KERNEL
            .take(array, &unsigned_indices, validity)
            .map(Some)
    }
}
111
112// Compiler may see this as unused based on enabled features
113#[allow(unused)]
114#[inline(always)]
115fn take_primitive_scalar<T: NativePType, I: IntegerPType>(
116    buffer: &[T],
117    indices: &[I],
118) -> Buffer<T> {
119    // NB: The simpler `indices.iter().map(|idx| buffer[idx.as_()]).collect()` generates suboptimal
120    // assembly where the buffer length is repeatedly loaded from the stack on each iteration.
121
122    let mut result = BufferMut::with_capacity(indices.len());
123    let ptr = result.spare_capacity_mut().as_mut_ptr().cast::<T>();
124
125    // This explicit loop with pointer writes keeps the length in a register and avoids per-element
126    // capacity checks from `push()`.
127    for (i, idx) in indices.iter().enumerate() {
128        // SAFETY: We reserved `indices.len()` capacity, so `ptr.add(i)` is valid.
129        unsafe { ptr.add(i).write(buffer[idx.as_()]) };
130    }
131
132    // SAFETY: We just wrote exactly `indices.len()` elements.
133    unsafe { result.set_len(indices.len()) };
134    result.freeze()
135}
136
// NOTE(review): this module was previously gated on
// `#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]`, which silently skipped every test
// here on other architectures (e.g. aarch64) even though nothing in the module is x86-specific:
// the scalar kernel and the conformance suite are architecture-independent, and the AVX2 kernel
// is exercised (when present) only indirectly through `PRIMITIVE_TAKE_KERNEL`. The gate has been
// dropped so the tests run everywhere.
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::primitive::compute::take::take_primitive_scalar;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Basic gather through the scalar kernel: repeated and out-of-order indices.
    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive_scalar(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    /// A result position is null if either the source value is null (index 3) or the index
    /// itself is null (third index).
    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = values.take(indices.to_array()).unwrap();
        assert_eq!(
            actual.scalar_at(0).vortex_expect("no fail"),
            Scalar::from(Some(1))
        );
        // position 3 is null
        assert_eq!(
            actual.scalar_at(1).vortex_expect("no fail"),
            Scalar::null_native::<i32>()
        );
        // the third index is null
        assert_eq!(
            actual.scalar_at(2).vortex_expect("no fail"),
            Scalar::null_native::<i32>()
        );
    }

    /// Conformance suite over a spread of lengths (1, 2, 5, 8 elements) and validity
    /// representations (NonNullable, AllValid, explicit bool mask, option-iter).
    #[rstest]
    #[case(PrimitiveArray::new(buffer![42i32], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4, 5, 6, 7], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid))]
    #[case(PrimitiveArray::new(
        buffer![0, 1, 2, 3, 4, 5],
        Validity::Array(BoolArray::from_iter([true, false, true, false, true, true]).into_array()),
    ))]
    #[case(PrimitiveArray::from_option_iter([Some(1), None, Some(3), Some(4), None]))]
    fn test_take_primitive_conformance(#[case] array: PrimitiveArray) {
        test_take_conformance(array.as_ref());
    }
}
201}