vortex_sequence/compute/
take.rs1use num_traits::cast::NumCast;
5use vortex_array::Array;
6use vortex_array::ArrayRef;
7use vortex_array::IntoArray;
8use vortex_array::ToCanonical;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::arrays::PrimitiveArray;
11use vortex_array::compute::TakeKernel;
12use vortex_array::compute::TakeKernelAdapter;
13use vortex_array::register_kernel;
14use vortex_array::validity::Validity;
15use vortex_buffer::Buffer;
16use vortex_dtype::DType;
17use vortex_dtype::IntegerPType;
18use vortex_dtype::NativePType;
19use vortex_dtype::Nullability;
20use vortex_dtype::match_each_integer_ptype;
21use vortex_dtype::match_each_native_ptype;
22use vortex_error::VortexExpect;
23use vortex_error::VortexResult;
24use vortex_error::vortex_panic;
25use vortex_mask::AllOr;
26use vortex_mask::Mask;
27use vortex_scalar::Scalar;
28
29use crate::SequenceArray;
30use crate::SequenceVTable;
31
32impl TakeKernel for SequenceVTable {
33 fn take(&self, array: &SequenceArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
34 let mask = indices.validity_mask();
35 let indices = indices.to_primitive();
36 let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
37
38 match_each_integer_ptype!(indices.ptype(), |T| {
39 let indices = indices.as_slice::<T>();
40 match_each_native_ptype!(array.ptype(), |S| {
41 let mul = array.multiplier().cast::<S>();
42 let base = array.base().cast::<S>();
43 Ok(take(
44 mul,
45 base,
46 indices,
47 mask,
48 result_nullability,
49 array.len(),
50 ))
51 })
52 })
53 }
54}
55
56fn take<T: IntegerPType, S: NativePType>(
57 mul: S,
58 base: S,
59 indices: &[T],
60 indices_mask: Mask,
61 result_nullability: Nullability,
62 len: usize,
63) -> ArrayRef {
64 match indices_mask.bit_buffer() {
65 AllOr::All => PrimitiveArray::new(
66 Buffer::from_trusted_len_iter(indices.iter().map(|i| {
67 if i.as_() >= len {
68 vortex_panic!(OutOfBounds: i.as_(), 0, len);
69 }
70 let i = <S as NumCast>::from::<T>(*i).vortex_expect("all indices fit");
71 base + i * mul
72 })),
73 Validity::from(result_nullability),
74 )
75 .into_array(),
76 AllOr::None => ConstantArray::new(
77 Scalar::null(DType::Primitive(S::PTYPE, Nullability::Nullable)),
78 indices.len(),
79 )
80 .into_array(),
81 AllOr::Some(b) => {
82 let buffer =
83 Buffer::from_trusted_len_iter(indices.iter().enumerate().map(|(mask_index, i)| {
84 if b.value(mask_index) {
85 if i.as_() >= len {
86 vortex_panic!(OutOfBounds: i.as_(), 0, len);
87 }
88
89 let i =
90 <S as NumCast>::from::<T>(*i).vortex_expect("all valid indices fit");
91 base + i * mul
92 } else {
93 S::zero()
94 }
95 }));
96 PrimitiveArray::new(buffer, Validity::from(b.clone())).into_array()
97 }
98 }
99}
100
// Registers the `TakeKernel` implementation for `SequenceVTable`, wrapped in a lifted
// `TakeKernelAdapter`.
// NOTE(review): presumably this is what lets the generic `take` compute function find this
// kernel; confirm against the definition of `register_kernel!`.
register_kernel!(TakeKernelAdapter(SequenceVTable).lift());
102
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::take;
    use vortex_dtype::Nullability;

    use crate::SequenceArray;

    // Runs the shared take conformance suite over sequences that differ in element type,
    // base, multiplier, nullability and length.
    #[rstest]
    #[case::basic_sequence(
        SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 10).unwrap()
    )]
    #[case::sequence_with_multiplier(
        SequenceArray::typed_new(10i32, 5i32, Nullability::Nullable, 20).unwrap()
    )]
    #[case::sequence_i64(
        SequenceArray::typed_new(100i64, 10i64, Nullability::NonNullable, 50).unwrap()
    )]
    #[case::sequence_u32(
        SequenceArray::typed_new(0u32, 2u32, Nullability::NonNullable, 100).unwrap()
    )]
    #[case::sequence_negative_step(
        SequenceArray::typed_new(1000i32, -10i32, Nullability::Nullable, 30).unwrap()
    )]
    #[case::sequence_constant(
        SequenceArray::typed_new(42i32, 0i32, Nullability::Nullable, 15).unwrap()
    )]
    #[case::sequence_i16(
        SequenceArray::typed_new(-100i16, 3i16, Nullability::NonNullable, 25).unwrap()
    )]
    #[case::sequence_large(
        SequenceArray::typed_new(0i64, 1i64, Nullability::Nullable, 1000).unwrap()
    )]
    fn test_take_conformance(#[case] sequence: SequenceArray) {
        // The helper shares this function's name, so it is called by its full path.
        vortex_array::compute::conformance::take::test_take_conformance(sequence.as_ref());
    }

    // Taking index 20 from a ten-element sequence must panic with an out-of-bounds message.
    #[test]
    #[should_panic(expected = "index 20 out of bounds")]
    fn test_bounds_check() {
        let sequence = SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 10).unwrap();
        let out_of_range = PrimitiveArray::from_iter([0i32, 20]);
        let _taken = take(sequence.as_ref(), out_of_range.as_ref()).unwrap();
    }
}