vortex_array/arrays/primitive/compute/
take.rs

1use std::simd;
2
3use num_traits::AsPrimitive;
4use simd::num::SimdUint;
5use vortex_buffer::{Alignment, Buffer, BufferMut};
6use vortex_dtype::{
7    NativePType, Nullability, PType, match_each_integer_ptype, match_each_native_ptype,
8    match_each_native_simd_ptype, match_each_unsigned_integer_ptype,
9};
10use vortex_error::VortexResult;
11
12use crate::arrays::PrimitiveVTable;
13use crate::arrays::primitive::PrimitiveArray;
14use crate::compute::{TakeKernel, TakeKernelAdapter};
15use crate::vtable::ValidityHelper;
16use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
17
18impl TakeKernel for PrimitiveVTable {
19    #[allow(clippy::cognitive_complexity)]
20    fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
21        let indices = indices.to_primitive()?;
22
23        if array.ptype() != PType::F16
24            && indices.dtype().is_unsigned_int()
25            && indices.all_valid()?
26            && array.all_valid()?
27        {
28            // TODO(alex): handle nullable codes & values
29            match_each_unsigned_integer_ptype!(indices.ptype(), |C| {
30                match_each_native_simd_ptype!(array.ptype(), |V| {
31                    // SIMD types larger than the SIMD register size are beneficial for
32                    // performance as this leads to better instruction level parallelism.
33                    let decoded = take_primitive_simd::<C, V, 64>(
34                        indices.as_slice(),
35                        array.as_slice(),
36                        array.dtype().nullability() | indices.dtype().nullability(),
37                    );
38
39                    return Ok(decoded.into_array()) as VortexResult<ArrayRef>;
40                })
41            });
42        }
43
44        // TODO(joe): if the true count of take indices validity is low, only take array values with
45        // valid indices.
46        let validity = array.validity().take(indices.as_ref())?;
47        match_each_native_ptype!(array.ptype(), |T| {
48            match_each_integer_ptype!(indices.ptype(), |I| {
49                let values = take_primitive(array.as_slice::<T>(), indices.as_slice::<I>());
50                Ok(PrimitiveArray::new(values, validity).into_array())
51            })
52        })
53    }
54}
55
56register_kernel!(TakeKernelAdapter(PrimitiveVTable).lift());
57
58fn take_primitive<T: NativePType, I: NativePType + AsPrimitive<usize>>(
59    array: &[T],
60    indices: &[I],
61) -> Buffer<T> {
62    indices.iter().map(|idx| array[idx.as_()]).collect()
63}
64
65/// Takes elements from an array using SIMD indexing.
66///
67/// # Type Parameters
68/// * `C` - Index type
69/// * `V` - Value type
70/// * `LANE_COUNT` - Number of SIMD lanes to process in parallel
71///
72/// # Parameters
73/// * `indices` - Indices to gather values from
74/// * `values` - Source values to index
75/// * `nullability` - Nullability of the resulting array
76///
77/// # Returns
78/// A `PrimitiveArray` containing the gathered values where each index has been replaced with
79/// the corresponding value from the source array.
80fn take_primitive_simd<I, V, const LANE_COUNT: usize>(
81    indices: &[I],
82    values: &[V],
83    nullability: Nullability,
84) -> PrimitiveArray
85where
86    I: simd::SimdElement + AsPrimitive<usize>,
87    V: simd::SimdElement + NativePType,
88    simd::LaneCount<LANE_COUNT>: simd::SupportedLaneCount,
89    simd::Simd<I, LANE_COUNT>: SimdUint<Cast<usize> = simd::Simd<usize, LANE_COUNT>>,
90{
91    let indices_len = indices.len();
92
93    let mut buffer = BufferMut::<V>::with_capacity_aligned(
94        indices_len,
95        Alignment::of::<simd::Simd<V, LANE_COUNT>>(),
96    );
97
98    let buf_slice = buffer.spare_capacity_mut();
99
100    for chunk_idx in 0..(indices_len / LANE_COUNT) {
101        let offset = chunk_idx * LANE_COUNT;
102        let mask = simd::Mask::from_bitmask(u64::MAX);
103        let codes_chunk = simd::Simd::<I, LANE_COUNT>::from_slice(&indices[offset..]);
104
105        unsafe {
106            let selection = simd::Simd::gather_select_unchecked(
107                values,
108                mask,
109                codes_chunk.cast::<usize>(),
110                simd::Simd::<V, LANE_COUNT>::default(),
111            );
112
113            selection.store_select_ptr(buf_slice.as_mut_ptr().add(offset) as *mut V, mask.cast());
114        }
115    }
116
117    for idx in ((indices_len / LANE_COUNT) * LANE_COUNT)..indices_len {
118        unsafe {
119            buf_slice
120                .get_unchecked_mut(idx)
121                .write(values[indices[idx].as_()]);
122        }
123    }
124
125    unsafe {
126        buffer.set_len(indices_len);
127    }
128
129    PrimitiveArray::new(buffer.freeze(), nullability.into())
130}
131
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::arrays::primitive::compute::take::take_primitive;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray};

    /// The scalar kernel returns values in index order, including repeated indices.
    #[test]
    fn test_take() {
        let source = vec![1i32, 2, 3, 4, 5];
        let gathered = take_primitive(&source, &[0, 0, 4, 2]);
        assert_eq!(gathered.as_slice(), &[1i32, 1, 5, 3]);
    }

    /// A result slot is null when either the index or the value it points at is null.
    #[test]
    fn test_take_with_null_indices() {
        let value_validity = BoolArray::from_iter([true, true, false, false, true]).into_array();
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(value_validity),
        );

        let index_validity = BoolArray::from_iter([true, true, false]).into_array();
        let indices = PrimitiveArray::new(buffer![0, 3, 4], Validity::Array(index_validity));

        let actual = take(values.as_ref(), indices.as_ref()).unwrap();

        let expected = [
            // index 0 is valid and points at a valid value
            Scalar::from(Some(1)),
            // position 3 is null
            Scalar::null_typed::<i32>(),
            // the third index is null
            Scalar::null_typed::<i32>(),
        ];
        for (position, want) in expected.into_iter().enumerate() {
            assert_eq!(actual.scalar_at(position).unwrap(), want);
        }
    }
}