vortex_array/arrays/primitive/compute/take/
mod.rs1#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
5mod avx2;
6
7#[cfg(vortex_nightly)]
8mod portable;
9
10use std::sync::LazyLock;
11
12use vortex_buffer::Buffer;
13use vortex_dtype::{
14 DType, IntegerPType, NativePType, match_each_integer_ptype, match_each_native_ptype,
15};
16use vortex_error::{VortexResult, vortex_bail};
17
18use crate::arrays::PrimitiveVTable;
19use crate::arrays::primitive::PrimitiveArray;
20use crate::compute::{TakeKernel, TakeKernelAdapter, cast};
21use crate::validity::Validity;
22use crate::vtable::ValidityHelper;
23use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
24
25static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| {
28 cfg_if::cfg_if! {
29 if #[cfg(vortex_nightly)] {
30 &portable::TakeKernelPortableSimd
32 } else if #[cfg(target_arch = "x86_64")] {
33 if is_x86_feature_detected!("avx2") {
36 &avx2::TakeKernelAVX2
37 } else {
38 &TakeKernelScalar
39 }
40 } else {
41 &TakeKernelScalar
43 }
44 }
45});
46
47trait TakeImpl: Send + Sync {
48 fn take(
49 &self,
50 array: &PrimitiveArray,
51 indices: &PrimitiveArray,
52 validity: Validity,
53 ) -> VortexResult<ArrayRef>;
54}
55
56#[allow(unused)]
57struct TakeKernelScalar;
58
59impl TakeImpl for TakeKernelScalar {
60 fn take(
61 &self,
62 array: &PrimitiveArray,
63 indices: &PrimitiveArray,
64 validity: Validity,
65 ) -> VortexResult<ArrayRef> {
66 match_each_native_ptype!(array.ptype(), |T| {
67 match_each_integer_ptype!(indices.ptype(), |I| {
68 let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
69 Ok(PrimitiveArray::new(values, validity).into_array())
70 })
71 })
72 }
73}
74
75impl TakeKernel for PrimitiveVTable {
76 fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
77 let DType::Primitive(ptype, null) = indices.dtype() else {
78 vortex_bail!("Invalid indices dtype: {}", indices.dtype())
79 };
80
81 let unsigned_indices = if ptype.is_unsigned_int() {
82 indices.to_primitive()
83 } else {
84 cast(indices, &DType::Primitive(ptype.to_unsigned(), *null))?.to_primitive()
86 };
87
88 let validity = array.validity().take(unsigned_indices.as_ref())?;
89 PRIMITIVE_TAKE_KERNEL.take(array, &unsigned_indices, validity)
91 }
92}
93
94register_kernel!(TakeKernelAdapter(PrimitiveVTable).lift());
95
96#[allow(unused)]
98#[inline(always)]
99fn take_primitive_scalar<T: NativePType, I: IntegerPType>(array: &[T], indices: &[I]) -> Buffer<T> {
100 indices.iter().map(|idx| array[idx.as_()]).collect()
101}
102
// These tests cover the scalar helper and the public `take` entry point, neither of which is
// architecture-specific. They therefore run on every target instead of only on x86/x86_64.
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::arrays::primitive::compute::take::take_primitive_scalar;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray};

    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive_scalar(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    #[test]
    fn test_take_unsigned_and_empty_indices() {
        let a = [10i64, 20, 30];
        let result = take_primitive_scalar(&a, &[2u8, 0, 2]);
        assert_eq!(result.as_slice(), &[30i64, 10, 30]);

        let no_indices: [u32; 0] = [];
        let empty = take_primitive_scalar(&a, &no_indices);
        assert!(empty.as_slice().is_empty());
    }

    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.len(), 3);
        assert_eq!(actual.scalar_at(0), Scalar::from(Some(1)));
        // The index 3 is valid, but values[3] is null.
        assert_eq!(actual.scalar_at(1), Scalar::null_typed::<i32>());
        // The third index is itself null, so the output is null regardless of values[4].
        assert_eq!(actual.scalar_at(2), Scalar::null_typed::<i32>());
    }

    #[rstest]
    #[case(PrimitiveArray::new(buffer![42i32], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4, 5, 6, 7], Validity::NonNullable))]
    #[case(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid))]
    #[case(PrimitiveArray::new(
        buffer![0, 1, 2, 3, 4, 5],
        Validity::Array(BoolArray::from_iter([true, false, true, false, true, true]).into_array()),
    ))]
    #[case(PrimitiveArray::from_option_iter([Some(1), None, Some(3), Some(4), None]))]
    fn test_take_primitive_conformance(#[case] array: PrimitiveArray) {
        test_take_conformance(array.as_ref());
    }
}