// vortex_array/arrays/primitive/compute/take.rs

1use std::simd;
2
3use num_traits::AsPrimitive;
4use simd::num::SimdUint;
5use vortex_buffer::{Alignment, Buffer, BufferMut};
6use vortex_dtype::{
7    NativePType, Nullability, PType, match_each_integer_ptype, match_each_native_ptype,
8    match_each_native_simd_ptype, match_each_unsigned_integer_ptype,
9};
10use vortex_error::{VortexResult, vortex_err};
11use vortex_mask::Mask;
12
13use crate::arrays::PrimitiveEncoding;
14use crate::arrays::primitive::PrimitiveArray;
15use crate::builders::{ArrayBuilder, PrimitiveBuilder};
16use crate::compute::TakeFn;
17use crate::variants::PrimitiveArrayTrait;
18use crate::{Array, ArrayRef, ToCanonical};
19
20impl TakeFn<&PrimitiveArray> for PrimitiveEncoding {
21    #[allow(clippy::cognitive_complexity)]
22    fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
23        let indices = indices.to_primitive()?;
24        let validity = array.validity().take(&indices)?;
25
26        if array.ptype() != PType::F16
27            && indices.dtype().is_unsigned_int()
28            && indices.all_valid()?
29            && array.all_valid()?
30        {
31            // TODO(alex): handle nullable codes & values
32            match_each_unsigned_integer_ptype!(indices.ptype(), |$C| {
33                match_each_native_simd_ptype!(array.ptype(), |$V| {
34                    // SIMD types larger than the SIMD register size are beneficial for
35                    // performance as this leads to better instruction level parallelism.
36                    let decoded = take_primitive_simd::<$C, $V, 64>(
37                        indices.as_slice(),
38                        array.as_slice(),
39                        array.dtype().nullability() | indices.dtype().nullability(),
40                    );
41
42                    return Ok(decoded.into_array()) as VortexResult<ArrayRef>;
43                })
44            });
45        }
46
47        match_each_native_ptype!(array.ptype(), |$T| {
48            match_each_integer_ptype!(indices.ptype(), |$I| {
49                let values = take_primitive(array.as_slice::<$T>(), indices.as_slice::<$I>());
50                Ok(PrimitiveArray::new(values, validity).into_array())
51            })
52        })
53    }
54
55    fn take_into(
56        &self,
57        array: &PrimitiveArray,
58        indices: &dyn Array,
59        builder: &mut dyn ArrayBuilder,
60    ) -> VortexResult<()> {
61        let indices = indices.to_primitive()?;
62        let mask = array.validity().take(&indices)?.to_mask(indices.len())?;
63
64        match_each_native_ptype!(array.ptype(), |$T| {
65            match_each_integer_ptype!(indices.ptype(), |$I| {
66                take_into_impl(array.as_slice::<$T>(), indices.as_slice::<$I>(), mask, builder)
67            })
68        })
69    }
70}
71
72fn take_into_impl<T: NativePType, I: NativePType + AsPrimitive<usize>>(
73    array: &[T],
74    indices: &[I],
75    mask: Mask,
76    builder: &mut dyn ArrayBuilder,
77) -> VortexResult<()> {
78    assert_eq!(indices.len(), mask.len());
79
80    let builder = builder
81        .as_any_mut()
82        .downcast_mut::<PrimitiveBuilder<T>>()
83        .ok_or_else(|| {
84            vortex_err!(
85                "Failed to downcast builder to PrimitiveBuilder<{}>",
86                T::PTYPE
87            )
88        })?;
89    builder.extend_with_iterator(indices.iter().map(|idx| array[idx.as_()]), mask);
90    Ok(())
91}
92
93fn take_primitive<T: NativePType, I: NativePType + AsPrimitive<usize>>(
94    array: &[T],
95    indices: &[I],
96) -> Buffer<T> {
97    indices.iter().map(|idx| array[idx.as_()]).collect()
98}
99
100/// Takes elements from an array using SIMD indexing.
101///
102/// # Type Parameters
103/// * `C` - Index type
104/// * `V` - Value type
105/// * `LANE_COUNT` - Number of SIMD lanes to process in parallel
106///
107/// # Parameters
108/// * `indices` - Indices to gather values from
109/// * `values` - Source values to index
110/// * `nullability` - Nullability of the resulting array
111///
112/// # Returns
113/// A `PrimitiveArray` containing the gathered values where each index has been replaced with
114/// the corresponding value from the source array.
115fn take_primitive_simd<I, V, const LANE_COUNT: usize>(
116    indices: &[I],
117    values: &[V],
118    nullability: Nullability,
119) -> PrimitiveArray
120where
121    I: simd::SimdElement + AsPrimitive<usize>,
122    V: simd::SimdElement + NativePType,
123    simd::LaneCount<LANE_COUNT>: simd::SupportedLaneCount,
124    simd::Simd<I, LANE_COUNT>: SimdUint<Cast<usize> = simd::Simd<usize, LANE_COUNT>>,
125{
126    let indices_len = indices.len();
127
128    let mut buffer = BufferMut::<V>::with_capacity_aligned(
129        indices_len,
130        Alignment::of::<simd::Simd<V, LANE_COUNT>>(),
131    );
132
133    let buf_slice = buffer.spare_capacity_mut();
134
135    for chunk_idx in 0..(indices_len / LANE_COUNT) {
136        let offset = chunk_idx * LANE_COUNT;
137        let mask = simd::Mask::from_bitmask(u64::MAX);
138        let codes_chunk = simd::Simd::<I, LANE_COUNT>::from_slice(&indices[offset..]);
139
140        unsafe {
141            let selection = simd::Simd::gather_select_unchecked(
142                values,
143                mask,
144                codes_chunk.cast::<usize>(),
145                simd::Simd::<V, LANE_COUNT>::default(),
146            );
147
148            selection.store_select_ptr(buf_slice.as_mut_ptr().add(offset) as *mut V, mask.cast());
149        }
150    }
151
152    for idx in ((indices_len / LANE_COUNT) * LANE_COUNT)..indices_len {
153        unsafe {
154            buf_slice
155                .get_unchecked_mut(idx)
156                .write(values[indices[idx].as_()]);
157        }
158    }
159
160    unsafe {
161        buffer.set_len(indices_len);
162    }
163
164    PrimitiveArray::new(buffer.freeze(), nullability.into())
165}
166
#[cfg(test)]
mod test {
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use crate::array::Array;
    use crate::arrays::primitive::compute::take::take_primitive;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::builders::{ArrayBuilder as _, PrimitiveBuilder};
    use crate::compute::{scalar_at, take, take_into};
    use crate::validity::Validity;

    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    /// Exercises the SIMD fast path (non-nullable values, non-nullable unsigned indices) with
    /// more than two full 64-lane chunks plus a partial tail, so both the vector loop and the
    /// scalar remainder loop run.
    #[test]
    fn test_take_simd_path() {
        let values = PrimitiveArray::new(
            (0..50i64).map(|v| v * 10).collect::<Buffer<i64>>(),
            Validity::NonNullable,
        );
        let index_values: Vec<u32> = (0..150u32).map(|i| (i * 7) % 50).collect();
        let indices = PrimitiveArray::new(
            index_values.iter().copied().collect::<Buffer<u32>>(),
            Validity::NonNullable,
        );

        let actual = take(&values, &indices).unwrap();
        assert_eq!(actual.len(), index_values.len());
        for (pos, &idx) in index_values.iter().enumerate() {
            assert_eq!(
                scalar_at(&actual, pos).unwrap(),
                Scalar::from(i64::from(idx) * 10)
            );
        }
    }

    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = take(&values, &indices).unwrap();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::from(Some(1)));
        // position 3 is null
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::null_typed::<i32>());
        // the third index is null
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<i32>());
    }

    #[test]
    fn test_take_into() {
        let values = PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5], Validity::NonNullable);
        let all_valid_indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, true]).into_array()),
        );
        let mut builder = PrimitiveBuilder::<i32>::new(Nullability::Nullable);
        take_into(&values, &all_valid_indices, &mut builder).unwrap();
        let actual = builder.finish();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::from(Some(1)));
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::from(Some(4)));
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::from(Some(5)));

        let mixed_valid_indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let mut builder = PrimitiveBuilder::<i32>::new(Nullability::Nullable);
        take_into(&values, &mixed_valid_indices, &mut builder).unwrap();
        let actual = builder.finish();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::from(Some(1)));
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::from(Some(4)));
        // the third index is null
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<i32>());

        let all_invalid_indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([false, false, false]).into_array()),
        );
        let mut builder = PrimitiveBuilder::<i32>::new(Nullability::Nullable);
        take_into(&values, &all_invalid_indices, &mut builder).unwrap();
        let actual = builder.finish();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::null_typed::<i32>());
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::null_typed::<i32>());
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<i32>());

        let non_null_indices = PrimitiveArray::new(buffer![0, 3, 4], Validity::NonNullable);
        let mut builder = PrimitiveBuilder::<i32>::new(Nullability::NonNullable);
        take_into(&values, &non_null_indices, &mut builder).unwrap();
        let actual = builder.finish();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::from(1));
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::from(4));
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::from(5));
    }
}