vortex_compute/take/slice/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Take function implementations on slices.
5
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_dtype::UnsignedPType;
9
10use crate::take::Take;
11
12pub mod avx2;
13pub mod portable;
14
15/// Specialized implementation for non-nullable indices.
16impl<T: Copy, I: UnsignedPType> Take<[I]> for &[T] {
17    type Output = Buffer<T>;
18
19    fn take(self, indices: &[I]) -> Buffer<T> {
20        // TODO(connor): Make the SIMD implementations bound by `Copy` instead of `NativePType`.
21        /*
22
23        #[cfg(vortex_nightly)]
24        {
25            return portable::take_portable(self, indices);
26        }
27
28        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
29        {
30            if is_x86_feature_detected!("avx2") {
31                // SAFETY: We just checked that the AVX2 feature in enabled.
32                return unsafe { avx2::take_avx2(self, indices) };
33            }
34        }
35
36        */
37
38        #[allow(unreachable_code, reason = "`vortex_nightly` path returns early")]
39        take_scalar(self, indices)
40    }
41}
42
43#[allow(
44    unused,
45    reason = "Compiler may see this as unused based on enabled features"
46)]
47fn take_scalar<T: Copy, I: UnsignedPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {
48    // NB: The simpler `indices.iter().map(|idx| buff1er[idx.as_()]).collect()` generates suboptimal
49    // assembly where the buffer length is repeatedly loaded from the stack on each iteration.
50
51    let mut result = BufferMut::with_capacity(indices.len());
52    let ptr = result.spare_capacity_mut().as_mut_ptr().cast::<T>();
53
54    // This explicit loop with pointer writes keeps the length in a register and avoids per-element
55    // capacity checks from `push()`.
56    for (i, idx) in indices.iter().enumerate() {
57        // SAFETY: We reserved `indices.len()` capacity, so `ptr.add(i)` is valid.
58        unsafe { ptr.add(i).write(buffer[idx.as_()]) };
59    }
60
61    // SAFETY: We just wrote exactly `indices.len()` elements.
62    unsafe { result.set_len(indices.len()) };
63    result.freeze()
64}