vortex_array/arrays/primitive/compute/take/
mod.rs1#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
5mod avx2;
6
7#[cfg(vortex_nightly)]
8mod portable;
9
10use std::sync::LazyLock;
11
12use num_traits::AsPrimitive;
13use vortex_buffer::Buffer;
14use vortex_dtype::{DType, NativePType, match_each_integer_ptype, match_each_native_ptype};
15use vortex_error::{VortexResult, vortex_bail};
16
17use crate::arrays::PrimitiveVTable;
18use crate::arrays::primitive::PrimitiveArray;
19use crate::compute::{TakeKernel, TakeKernelAdapter, cast};
20use crate::validity::Validity;
21use crate::vtable::ValidityHelper;
22use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
23
24static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(|| {
27 cfg_if::cfg_if! {
28 if #[cfg(vortex_nightly)] {
29 &portable::TakeKernelPortableSimd
31 } else if #[cfg(target_arch = "x86_64")] {
32 if is_x86_feature_detected!("avx2") {
35 &avx2::TakeKernelAVX2
36 } else {
37 &TakeKernelScalar
38 }
39 } else {
40 &TakeKernelScalar
42 }
43 }
44});
45
46trait TakeImpl: Send + Sync {
47 fn take(
48 &self,
49 array: &PrimitiveArray,
50 indices: &PrimitiveArray,
51 validity: Validity,
52 ) -> VortexResult<ArrayRef>;
53}
54
/// Scalar (non-SIMD) fallback [`TakeImpl`] that copies the selected elements one by one.
// Only `vortex_nightly` builds (which select the portable-SIMD kernel) leave this type
// unused, so the dead-code lint is silenced there and kept active everywhere else.
#[cfg_attr(vortex_nightly, allow(dead_code))]
struct TakeKernelScalar;
57
58impl TakeImpl for TakeKernelScalar {
59 fn take(
60 &self,
61 array: &PrimitiveArray,
62 indices: &PrimitiveArray,
63 validity: Validity,
64 ) -> VortexResult<ArrayRef> {
65 match_each_native_ptype!(array.ptype(), |T| {
66 match_each_integer_ptype!(indices.ptype(), |I| {
67 let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
68 Ok(PrimitiveArray::new(values, validity).into_array())
69 })
70 })
71 }
72}
73
74impl TakeKernel for PrimitiveVTable {
75 fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
76 let unsigned_indices = match indices.dtype() {
77 DType::Primitive(p, n) => {
78 if p.is_unsigned_int() {
79 indices.to_primitive()?
80 } else {
81 cast(indices, &DType::Primitive(p.to_unsigned(), *n))?.to_primitive()?
83 }
84 }
85 _ => vortex_bail!("Invalid indices dtype: {}", indices.dtype()),
86 };
87 let validity = array.validity().take(unsigned_indices.as_ref())?;
88 PRIMITIVE_TAKE_KERNEL.take(array, &unsigned_indices, validity)
90 }
91}
92
93register_kernel!(TakeKernelAdapter(PrimitiveVTable).lift());
94
95#[allow(unused)]
97#[inline(always)]
98fn take_primitive_scalar<T: NativePType, I: NativePType + AsPrimitive<usize>>(
99 array: &[T],
100 indices: &[I],
101) -> Buffer<T> {
102 indices.iter().map(|idx| array[idx.as_()]).collect()
103}
104
// These tests exercise only the scalar helper and the architecture-independent `take`
// entry point, so they are not gated on a target architecture.
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::arrays::primitive::compute::take::take_primitive_scalar;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray, ToCanonical};

    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3, 4, 5];
        let result = take_primitive_scalar(&a, &[0, 0, 4, 2]);
        assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
    }

    #[test]
    fn test_take_unsigned_indices() {
        let values = PrimitiveArray::new(buffer![10i64, 20, 30, 40], Validity::NonNullable);
        let indices = PrimitiveArray::new(buffer![3u32, 0, 0, 2], Validity::NonNullable);
        let actual = take(values.as_ref(), indices.as_ref())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(actual.as_slice::<i64>(), &[40i64, 10, 10, 30]);
    }

    #[test]
    fn test_take_with_null_indices() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 3, 4, 5],
            Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()),
        );
        let indices = PrimitiveArray::new(
            buffer![0, 3, 4],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.len(), 3);
        assert_eq!(actual.scalar_at(0).unwrap(), Scalar::from(Some(1)));
        // Index 3 is valid, but the value it points at is null.
        assert_eq!(actual.scalar_at(1).unwrap(), Scalar::null_typed::<i32>());
        // The index itself is null.
        assert_eq!(actual.scalar_at(2).unwrap(), Scalar::null_typed::<i32>());
    }
}