Skip to main content

vortex_array/arrays/primitive/compute/take/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
5mod avx2;
6
7use std::sync::LazyLock;
8
9use vortex_buffer::Buffer;
10use vortex_buffer::BufferMut;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13
14use crate::ArrayRef;
15use crate::IntoArray;
16use crate::array::ArrayView;
17use crate::arrays::Primitive;
18use crate::arrays::PrimitiveArray;
19use crate::arrays::dict::TakeExecute;
20use crate::builtins::ArrayBuiltins;
21use crate::dtype::DType;
22use crate::dtype::IntegerPType;
23use crate::dtype::NativePType;
24use crate::executor::ExecutionCtx;
25use crate::match_each_integer_ptype;
26use crate::match_each_native_ptype;
27use crate::validity::Validity;
28
29// Kernel selection happens on the first call to `take` and uses a combination of compile-time
30// and runtime feature detection to infer the best kernel for the platform.
31static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| {
32    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
33    {
34        if is_x86_feature_detected!("avx2") {
35            &avx2::TakeKernelAVX2
36        } else {
37            &TakeKernelScalar
38        }
39    }
40
41    #[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
42    {
43        &TakeKernelScalar
44    }
45});
46
47trait TakeImpl: Send + Sync {
48    fn take(
49        &self,
50        array: ArrayView<'_, Primitive>,
51        indices: ArrayView<'_, Primitive>,
52        validity: Validity,
53    ) -> VortexResult<ArrayRef>;
54}
55
56struct TakeKernelScalar;
57
58impl TakeImpl for TakeKernelScalar {
59    fn take(
60        &self,
61        array: ArrayView<'_, Primitive>,
62        indices: ArrayView<'_, Primitive>,
63        validity: Validity,
64    ) -> VortexResult<ArrayRef> {
65        match_each_native_ptype!(array.ptype(), |T| {
66            match_each_integer_ptype!(indices.ptype(), |I| {
67                let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
68                Ok(PrimitiveArray::new(values, validity).into_array())
69            })
70        })
71    }
72}
73
74impl TakeExecute for Primitive {
75    fn take(
76        array: ArrayView<'_, Primitive>,
77        indices: &ArrayRef,
78        ctx: &mut ExecutionCtx,
79    ) -> VortexResult<Option<ArrayRef>> {
80        let DType::Primitive(ptype, null) = indices.dtype() else {
81            vortex_bail!("Invalid indices dtype: {}", indices.dtype())
82        };
83
84        let unsigned_indices = if ptype.is_unsigned_int() {
85            indices.clone().execute::<PrimitiveArray>(ctx)?
86        } else {
87            // This will fail if all values cannot be converted to unsigned
88            indices
89                .clone()
90                .cast(DType::Primitive(ptype.to_unsigned(), *null))?
91                .execute::<PrimitiveArray>(ctx)?
92        };
93
94        let validity = array
95            .validity()?
96            .take(&unsigned_indices.clone().into_array())?;
97        // Delegate to the best kernel based on the target CPU
98        {
99            let unsigned_indices = unsigned_indices.as_view();
100            PRIMITIVE_TAKE_KERNEL
101                .take(array, unsigned_indices, validity)
102                .map(Some)
103        }
104    }
105}
106
107// Compiler may see this as unused based on enabled features
108#[inline(always)]
109fn take_primitive_scalar<T: NativePType, I: IntegerPType>(
110    buffer: &[T],
111    indices: &[I],
112) -> Buffer<T> {
113    // NB: The simpler `indices.iter().map(|idx| buffer[idx.as_()]).collect()` generates suboptimal
114    // assembly where the buffer length is repeatedly loaded from the stack on each iteration.
115
116    let mut result = BufferMut::with_capacity(indices.len());
117    let ptr = result.spare_capacity_mut().as_mut_ptr().cast::<T>();
118
119    // This explicit loop with pointer writes keeps the length in a register and avoids per-element
120    // capacity checks from `push()`.
121    for (i, idx) in indices.iter().enumerate() {
122        // SAFETY: We reserved `indices.len()` capacity, so `ptr.add(i)` is valid.
123        unsafe { ptr.add(i).write(buffer[idx.as_()]) };
124    }
125
126    // SAFETY: We just wrote exactly `indices.len()` elements.
127    unsafe { result.set_len(indices.len()) };
128    result.freeze()
129}
130
// Nothing in this module is architecture-specific: it exercises the portable scalar helper and
// the public `take` entry point. It is therefore compiled for every target, so that platforms
// which only ever use the scalar kernel still get test coverage.
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::primitive::compute::take::take_primitive_scalar;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// The scalar helper gathers values in index order, including repeated indices.
    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive_scalar(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    /// An empty index list produces an empty buffer.
    #[test]
    fn test_take_empty_indices() {
        let a = vec![1i32, 2, 3, 4, 5];
        let no_indices: [u32; 0] = [];
        let result = take_primitive_scalar(&a, &no_indices);
        assert!(result.as_slice().is_empty());
    }

    /// Nulls propagate both from the taken values and from the indices themselves.
    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();
        // index 0 points at a valid value
        assert_eq!(
            actual
                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
                .vortex_expect("no fail"),
            Scalar::from(Some(1))
        );
        // index 3 points at a null value
        assert_eq!(
            actual
                .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())
                .vortex_expect("no fail"),
            Scalar::null_native::<i32>()
        );
        // the third index is itself null
        assert_eq!(
            actual
                .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
                .vortex_expect("no fail"),
            Scalar::null_native::<i32>()
        );
    }

    /// Runs the shared `take` conformance suite over arrays of several lengths and validities.
    #[rstest]
    #[case(PrimitiveArray::new(buffer![42i32], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4, 5, 6, 7], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid))]
    #[case(PrimitiveArray::new(
        buffer![0, 1, 2, 3, 4, 5],
        Validity::Array(BoolArray::from_iter([true, false, true, false, true, true]).into_array()),
    ))]
    #[case(PrimitiveArray::from_option_iter([Some(1), None, Some(3), Some(4), None]))]
    fn test_take_primitive_conformance(#[case] array: PrimitiveArray) {
        test_take_conformance(&array.into_array());
    }
}