vortex_sequence/compute/
take.rs1use num_traits::cast::NumCast;
5use vortex_array::arrays::{ConstantArray, PrimitiveArray};
6use vortex_array::compute::{TakeKernel, TakeKernelAdapter};
7use vortex_array::validity::Validity;
8use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
9use vortex_buffer::Buffer;
10use vortex_dtype::{
11 DType, IntegerPType, NativePType, Nullability, match_each_integer_ptype,
12 match_each_native_ptype,
13};
14use vortex_error::{VortexExpect, VortexResult, vortex_panic};
15use vortex_mask::{AllOr, Mask};
16use vortex_scalar::Scalar;
17
18use crate::{SequenceArray, SequenceVTable};
19
20impl TakeKernel for SequenceVTable {
21 fn take(&self, array: &SequenceArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
22 let mask = indices.validity_mask();
23 let indices = indices.to_primitive();
24 let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
25
26 match_each_integer_ptype!(indices.ptype(), |T| {
27 let indices = indices.as_slice::<T>();
28 match_each_native_ptype!(array.ptype(), |S| {
29 let mul = array.multiplier().cast::<S>();
30 let base = array.base().cast::<S>();
31 Ok(take(
32 mul,
33 base,
34 indices,
35 mask,
36 result_nullability,
37 array.len(),
38 ))
39 })
40 })
41 }
42}
43
44fn take<T: IntegerPType, S: NativePType>(
45 mul: S,
46 base: S,
47 indices: &[T],
48 indices_mask: Mask,
49 result_nullability: Nullability,
50 len: usize,
51) -> ArrayRef {
52 match indices_mask.bit_buffer() {
53 AllOr::All => PrimitiveArray::new(
54 Buffer::from_trusted_len_iter(indices.iter().map(|i| {
55 if i.as_() >= len {
56 vortex_panic!(OutOfBounds: i.as_(), 0, len);
57 }
58 let i = <S as NumCast>::from::<T>(*i).vortex_expect("all indices fit");
59 base + i * mul
60 })),
61 Validity::from(result_nullability),
62 )
63 .into_array(),
64 AllOr::None => ConstantArray::new(
65 Scalar::null(DType::Primitive(S::PTYPE, Nullability::Nullable)),
66 indices.len(),
67 )
68 .into_array(),
69 AllOr::Some(b) => {
70 let buffer =
71 Buffer::from_trusted_len_iter(indices.iter().enumerate().map(|(mask_index, i)| {
72 if b.value(mask_index) {
73 if i.as_() >= len {
74 vortex_panic!(OutOfBounds: i.as_(), 0, len);
75 }
76
77 let i =
78 <S as NumCast>::from::<T>(*i).vortex_expect("all valid indices fit");
79 base + i * mul
80 } else {
81 S::zero()
82 }
83 }));
84 PrimitiveArray::new(buffer, Validity::from(b.clone())).into_array()
85 }
86 }
87}
88
89register_kernel!(TakeKernelAdapter(SequenceVTable).lift());
90
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    // Aliased because the rstest function below is itself named `test_take_conformance`.
    use vortex_array::compute::conformance::take::test_take_conformance as check_take_conformance;
    use vortex_array::compute::take;
    use vortex_dtype::Nullability;

    use crate::SequenceArray;

    /// Runs the shared take conformance suite against sequences covering several integer
    /// widths and signs, positive/negative/zero multipliers, and both nullabilities.
    #[rstest]
    #[case::basic_sequence(
        SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 10).unwrap()
    )]
    #[case::sequence_with_multiplier(
        SequenceArray::typed_new(10i32, 5i32, Nullability::Nullable, 20).unwrap()
    )]
    #[case::sequence_i64(
        SequenceArray::typed_new(100i64, 10i64, Nullability::NonNullable, 50).unwrap()
    )]
    #[case::sequence_u32(
        SequenceArray::typed_new(0u32, 2u32, Nullability::NonNullable, 100).unwrap()
    )]
    #[case::sequence_negative_step(
        SequenceArray::typed_new(1000i32, -10i32, Nullability::Nullable, 30).unwrap()
    )]
    #[case::sequence_constant(
        SequenceArray::typed_new(42i32, 0i32, Nullability::Nullable, 15).unwrap()
    )]
    #[case::sequence_i16(
        SequenceArray::typed_new(-100i16, 3i16, Nullability::NonNullable, 25).unwrap()
    )]
    #[case::sequence_large(
        SequenceArray::typed_new(0i64, 1i64, Nullability::Nullable, 1000).unwrap()
    )]
    fn test_take_conformance(#[case] sequence: SequenceArray) {
        check_take_conformance(sequence.as_ref());
    }

    /// Taking past the end of a sequence must panic with an out-of-bounds error.
    #[test]
    #[should_panic(expected = "index 20 out of bounds")]
    fn test_bounds_check() {
        let sequence = SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 10).unwrap();
        // Index 20 lies past the end of the ten-element sequence.
        let out_of_range = PrimitiveArray::from_iter([0i32, 20]);
        let _taken = take(sequence.as_ref(), out_of_range.as_ref()).unwrap();
    }
}