vortex_compute/take/slice/mod.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Take function implementations on slices.
5
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_dtype::UnsignedPType;
9
10use crate::take::Take;
11
12pub mod avx2;
13pub mod portable;
14
15/// Specialized implementation for non-nullable indices.
16impl<T: Copy, I: UnsignedPType> Take<[I]> for &[T] {
17 type Output = Buffer<T>;
18
19 fn take(self, indices: &[I]) -> Buffer<T> {
20 // TODO(connor): Make the SIMD implementations bound by `Copy` instead of `NativePType`.
21 /*
22
23 #[cfg(vortex_nightly)]
24 {
25 return portable::take_portable(self, indices);
26 }
27
28 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
29 {
30 if is_x86_feature_detected!("avx2") {
31 // SAFETY: We just checked that the AVX2 feature in enabled.
32 return unsafe { avx2::take_avx2(self, indices) };
33 }
34 }
35
36 */
37
38 #[allow(unreachable_code, reason = "`vortex_nightly` path returns early")]
39 take_scalar(self, indices)
40 }
41}
42
43#[allow(
44 unused,
45 reason = "Compiler may see this as unused based on enabled features"
46)]
47fn take_scalar<T: Copy, I: UnsignedPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {
48 // NB: The simpler `indices.iter().map(|idx| buff1er[idx.as_()]).collect()` generates suboptimal
49 // assembly where the buffer length is repeatedly loaded from the stack on each iteration.
50
51 let mut result = BufferMut::with_capacity(indices.len());
52 let ptr = result.spare_capacity_mut().as_mut_ptr().cast::<T>();
53
54 // This explicit loop with pointer writes keeps the length in a register and avoids per-element
55 // capacity checks from `push()`.
56 for (i, idx) in indices.iter().enumerate() {
57 // SAFETY: We reserved `indices.len()` capacity, so `ptr.add(i)` is valid.
58 unsafe { ptr.add(i).write(buffer[idx.as_()]) };
59 }
60
61 // SAFETY: We just wrote exactly `indices.len()` elements.
62 unsafe { result.set_len(indices.len()) };
63 result.freeze()
64}