// vortex_array/arrays/primitive/compute/take/mod.rs
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
5mod avx2;
6
7#[cfg(vortex_nightly)]
8mod portable;
9
10use std::sync::LazyLock;
11
12use vortex_buffer::Buffer;
13use vortex_dtype::DType;
14use vortex_dtype::IntegerPType;
15use vortex_dtype::NativePType;
16use vortex_dtype::match_each_integer_ptype;
17use vortex_dtype::match_each_native_ptype;
18use vortex_error::VortexResult;
19use vortex_error::vortex_bail;
20
21use crate::Array;
22use crate::ArrayRef;
23use crate::IntoArray;
24use crate::ToCanonical;
25use crate::arrays::PrimitiveVTable;
26use crate::arrays::primitive::PrimitiveArray;
27use crate::compute::TakeKernel;
28use crate::compute::TakeKernelAdapter;
29use crate::compute::cast;
30use crate::register_kernel;
31use crate::validity::Validity;
32use crate::vtable::ValidityHelper;
33
/// The [`TakeImpl`] used by the primitive take kernel in this module.
///
/// The `LazyLock` runs the selection closure on first access and reuses the chosen kernel for
/// every later call. Compile-time configuration narrows the candidates; on `x86_64` the final
/// choice additionally depends on a runtime CPU feature check.
///
/// NOTE(review): the `avx2` module is compiled for both `x86` and `x86_64`, but it is only ever
/// selected here on `x86_64` — confirm whether 32-bit x86 is meant to use the scalar kernel.
static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| {
    cfg_if::cfg_if! {
        if #[cfg(vortex_nightly)] {
            // `vortex_nightly` builds: always use the kernel from the `portable` module.
            &portable::TakeKernelPortableSimd
        } else if #[cfg(target_arch = "x86_64")] {
            // Other x86_64 builds: use the AVX2 kernel only when the running CPU reports `avx2`
            // support at runtime, otherwise fall back to the scalar kernel.
            if is_x86_feature_detected!("avx2") {
                &avx2::TakeKernelAVX2
            } else {
                &TakeKernelScalar
            }
        } else {
            // Every other target: scalar kernel.
            &TakeKernelScalar
        }
    }
});
55
/// A strategy for gathering the values of a [`PrimitiveArray`] at a set of indices.
///
/// Implementations are stored behind a `&'static dyn TakeImpl` in a `static`, hence the
/// `Send + Sync` bound.
trait TakeImpl: Send + Sync {
    /// Gathers the values of `array` at the positions in `indices`.
    ///
    /// `validity` is supplied by the caller rather than derived here; the scalar implementation
    /// in this file attaches it unchanged to the array it returns.
    fn take(
        &self,
        array: &PrimitiveArray,
        indices: &PrimitiveArray,
        validity: Validity,
    ) -> VortexResult<ArrayRef>;
}
64
/// Fallback [`TakeImpl`] with no SIMD requirements.
// `vortex_nightly` builds select the portable kernel unconditionally and never reference this
// type, so it would otherwise be reported as unused in that configuration.
#[allow(unused)]
struct TakeKernelScalar;
67
68impl TakeImpl for TakeKernelScalar {
69 fn take(
70 &self,
71 array: &PrimitiveArray,
72 indices: &PrimitiveArray,
73 validity: Validity,
74 ) -> VortexResult<ArrayRef> {
75 match_each_native_ptype!(array.ptype(), |T| {
76 match_each_integer_ptype!(indices.ptype(), |I| {
77 let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
78 Ok(PrimitiveArray::new(values, validity).into_array())
79 })
80 })
81 }
82}
83
84impl TakeKernel for PrimitiveVTable {
85 fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
86 let DType::Primitive(ptype, null) = indices.dtype() else {
87 vortex_bail!("Invalid indices dtype: {}", indices.dtype())
88 };
89
90 let unsigned_indices = if ptype.is_unsigned_int() {
91 indices.to_primitive()
92 } else {
93 cast(indices, &DType::Primitive(ptype.to_unsigned(), *null))?.to_primitive()
95 };
96
97 let validity = array.validity().take(unsigned_indices.as_ref())?;
98 PRIMITIVE_TAKE_KERNEL.take(array, &unsigned_indices, validity)
100 }
101}
102
// Registers the `TakeKernel` implementation for `PrimitiveVTable`, wrapped in a
// `TakeKernelAdapter`. NOTE(review): presumably this makes the kernel discoverable by the `take`
// compute function — confirm against the `register_kernel!` definition.
register_kernel!(TakeKernelAdapter(PrimitiveVTable).lift());
104
105#[allow(unused)]
107#[inline(always)]
108fn take_primitive_scalar<T: NativePType, I: IntegerPType>(array: &[T], indices: &[I]) -> Buffer<T> {
109 indices.iter().map(|idx| array[idx.as_()]).collect()
110}
111
// These tests exercise only architecture-independent code paths (the scalar gather and the
// public `take` compute function), so they are compiled on every target rather than only on x86.
// This matters most on non-x86 targets, where the scalar kernel is the one actually selected.
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::primitive::compute::take::take_primitive_scalar;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// The scalar gather returns the values in index order, including repeated indices.
    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive_scalar(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    /// A result slot is null when either the selected value or the index itself is null.
    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.scalar_at(0), Scalar::from(Some(1)));
        // The value at position 3 is null.
        assert_eq!(actual.scalar_at(1), Scalar::null_typed::<i32>());
        // The third index is itself null.
        assert_eq!(actual.scalar_at(2), Scalar::null_typed::<i32>());
    }

    /// Runs the shared take conformance suite over arrays of several lengths and validities.
    #[rstest]
    #[case(PrimitiveArray::new(buffer![42i32], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4, 5, 6, 7], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid))]
    #[case(PrimitiveArray::new(
        buffer![0, 1, 2, 3, 4, 5],
        Validity::Array(BoolArray::from_iter([true, false, true, false, true, true]).into_array()),
    ))]
    #[case(PrimitiveArray::from_option_iter([Some(1), None, Some(3), Some(4), None]))]
    fn test_take_primitive_conformance(#[case] array: PrimitiveArray) {
        test_take_conformance(array.as_ref());
    }
}