vortex_array/arrays/primitive/compute/take/
mod.rs1#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
2mod avx2;
3
4#[cfg(feature = "nightly")]
5mod portable;
6
7use std::sync::LazyLock;
8
9use num_traits::AsPrimitive;
10use vortex_buffer::Buffer;
11use vortex_dtype::{DType, NativePType, match_each_integer_ptype, match_each_native_ptype};
12use vortex_error::{VortexResult, vortex_bail};
13
14use crate::arrays::PrimitiveVTable;
15use crate::arrays::primitive::PrimitiveArray;
16use crate::compute::{TakeKernel, TakeKernelAdapter, cast};
17use crate::validity::Validity;
18use crate::vtable::ValidityHelper;
19use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
20
21static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| {
24 cfg_if::cfg_if! {
25 if #[cfg(feature = "nightly")] {
26 &portable::TakeKernelPortableSimd
28 } else if #[cfg(target_arch = "x86_64")] {
29 if is_x86_feature_detected!("avx2") {
32 &avx2::TakeKernelAVX2
33 } else {
34 &TakeKernelScalar
35 }
36 } else {
37 &TakeKernelScalar
39 }
40 }
41});
42
43trait TakeImpl: Send + Sync {
44 fn take(
45 &self,
46 array: &PrimitiveArray,
47 indices: &PrimitiveArray,
48 validity: Validity,
49 ) -> VortexResult<ArrayRef>;
50}
51
52#[allow(unused)]
53struct TakeKernelScalar;
54
55impl TakeImpl for TakeKernelScalar {
56 fn take(
57 &self,
58 array: &PrimitiveArray,
59 indices: &PrimitiveArray,
60 validity: Validity,
61 ) -> VortexResult<ArrayRef> {
62 match_each_native_ptype!(array.ptype(), |T| {
63 match_each_integer_ptype!(indices.ptype(), |I| {
64 let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
65 Ok(PrimitiveArray::new(values, validity).into_array())
66 })
67 })
68 }
69}
70
71impl TakeKernel for PrimitiveVTable {
72 fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
73 let unsigned_indices = match indices.dtype() {
74 DType::Primitive(p, n) => {
75 if p.is_unsigned_int() {
76 indices.to_primitive()?
77 } else {
78 cast(indices, &DType::Primitive(p.to_unsigned(), *n))?.to_primitive()?
80 }
81 }
82 _ => vortex_bail!("Invalid indices dtype: {}", indices.dtype()),
83 };
84 let validity = array.validity().take(unsigned_indices.as_ref())?;
85 PRIMITIVE_TAKE_KERNEL.take(array, &unsigned_indices, validity)
87 }
88}
89
90register_kernel!(TakeKernelAdapter(PrimitiveVTable).lift());
91
92#[allow(unused)]
94#[inline(always)]
95fn take_primitive_scalar<T: NativePType, I: NativePType + AsPrimitive<usize>>(
96 array: &[T],
97 indices: &[I],
98) -> Buffer<T> {
99 indices.iter().map(|idx| array[idx.as_()]).collect()
100}
101
// These tests exercise the scalar gather and the `take` entry point, neither of which is
// architecture-specific, so they run on every target (previously they were gated to x86).
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::arrays::primitive::compute::take::take_primitive_scalar;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray};

    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive_scalar(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    #[test]
    fn test_take_scalar_float_values_u8_indices() {
        let values = [1.5f64, 2.5, 3.5];
        let result = take_primitive_scalar(&values, &[2u8, 2, 0]);
        assert_eq!(result.as_slice(), &[3.5f64, 3.5, 1.5]);
    }

    #[test]
    fn test_take_scalar_empty_indices() {
        let values = [1i64, 2, 3];
        let result = take_primitive_scalar::<i64, u32>(&values, &[]);
        assert!(result.as_slice().is_empty());
    }

    #[test]
    fn test_take_unsigned_indices() {
        // Unsigned indices skip the cast path in the kernel entry point.
        let values = PrimitiveArray::new(buffer![10u64, 20, 30], Validity::NonNullable);
        let indices = PrimitiveArray::new(buffer![2u32, 0, 1], Validity::NonNullable);
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.len(), 3);
        assert_eq!(actual.scalar_at(0).unwrap(), Scalar::from(30u64));
        assert_eq!(actual.scalar_at(1).unwrap(), Scalar::from(10u64));
        assert_eq!(actual.scalar_at(2).unwrap(), Scalar::from(20u64));
    }

    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.len(), 3);
        assert_eq!(actual.scalar_at(0).unwrap(), Scalar::from(Some(1)));
        // Index 3 is valid but points at a null value.
        assert_eq!(actual.scalar_at(1).unwrap(), Scalar::null_typed::<i32>());
        // The index itself is null.
        assert_eq!(actual.scalar_at(2).unwrap(), Scalar::null_typed::<i32>());
    }
}