1use num_traits::{AsPrimitive, NumCast};
5use vortex_array::arrays::PrimitiveArray;
6use vortex_array::compute::{TakeKernel, TakeKernelAdapter, take};
7use vortex_array::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
8use vortex_array::validity::Validity;
9use vortex_array::vtable::ValidityHelper;
10use vortex_array::{Array, ArrayRef, ToCanonical, register_kernel};
11use vortex_buffer::Buffer;
12use vortex_dtype::match_each_integer_ptype;
13use vortex_error::{VortexResult, vortex_bail};
14
15use crate::{RunEndArray, RunEndVTable};
16
17impl TakeKernel for RunEndVTable {
18 #[allow(clippy::cast_possible_truncation)]
19 fn take(&self, array: &RunEndArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
20 let primitive_indices = indices.to_primitive()?;
21
22 let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| {
23 primitive_indices
24 .as_slice::<P>()
25 .iter()
26 .copied()
27 .map(|idx| {
28 let usize_idx = idx as usize;
29 if usize_idx >= array.len() {
30 vortex_bail!(OutOfBounds: usize_idx, 0, array.len());
31 }
32 Ok(usize_idx)
33 })
34 .collect::<VortexResult<Vec<_>>>()?
35 });
36
37 take_indices_unchecked(array, &checked_indices, primitive_indices.validity())
38 }
39}
40
// Registers the `TakeKernel` impl for `RunEndVTable` via vortex_array's `register_kernel!`.
// Presumably this lets the generic `take` compute function dispatch to it for run-end
// arrays. NOTE(review): confirm the exact role of `.lift()` in vortex_array.
register_kernel!(TakeKernelAdapter(RunEndVTable).lift());
42
43pub fn take_indices_unchecked<T: AsPrimitive<usize>>(
45 array: &RunEndArray,
46 indices: &[T],
47 validity: &Validity,
48) -> VortexResult<ArrayRef> {
49 let ends = array.ends().to_primitive()?;
50 let ends_len = ends.len();
51
52 let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| {
54 let end_slices = ends.as_slice::<I>();
55 let buffer = Buffer::from_trusted_len_iter(
56 indices
57 .iter()
58 .map(|idx| idx.as_() + array.offset())
59 .map(|idx| {
60 match <I as NumCast>::from(idx) {
61 Some(idx) => end_slices.search_sorted(&idx, SearchSortedSide::Right),
62 None => {
63 SearchResult::NotFound(ends_len)
65 }
66 }
67 })
68 .map(|result| result.to_ends_index(ends_len) as u64),
69 );
70
71 PrimitiveArray::new(buffer, validity.clone())
72 });
73
74 take(array.values(), physical_indices.as_ref())
75}
76
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::compute::take;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::{Scalar, ScalarValue};

    use crate::RunEndArray;

    /// Run-end encodes `[1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5]` (4 runs, 12 logical elements).
    fn ree_array() -> RunEndArray {
        RunEndArray::encode(
            PrimitiveArray::from_iter([1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5]).into_array(),
        )
        .unwrap()
    }

    #[test]
    fn ree_take() {
        let taken = take(
            ree_array().as_ref(),
            PrimitiveArray::from_iter([9, 8, 1, 3]).as_ref(),
        )
        .unwrap();
        assert_eq!(
            taken.to_primitive().unwrap().as_slice::<i32>(),
            &[5, 5, 1, 4]
        );
    }

    #[test]
    fn ree_take_end() {
        let taken = take(
            ree_array().as_ref(),
            PrimitiveArray::from_iter([11]).as_ref(),
        )
        .unwrap();
        assert_eq!(taken.to_primitive().unwrap().as_slice::<i32>(), &[5]);
    }

    #[test]
    fn ree_take_out_of_bounds() {
        // Assert on the returned error rather than using `#[should_panic]`, which would also
        // pass on an unrelated panic.
        let result = take(
            ree_array().as_ref(),
            PrimitiveArray::from_iter([12]).as_ref(),
        );
        assert!(result.is_err());
    }

    #[test]
    fn sliced_take_out_of_bounds() {
        // Index 5 is in bounds for the parent array but not for the 5-element slice.
        let sliced = ree_array().slice(4, 9).unwrap();
        let result = take(sliced.as_ref(), PrimitiveArray::from_iter([5]).as_ref());
        assert!(result.is_err());
    }

    #[test]
    fn sliced_take() {
        let sliced = ree_array().slice(4, 9).unwrap();
        let taken = take(
            sliced.as_ref(),
            PrimitiveArray::from_iter([1, 3, 4]).as_ref(),
        )
        .unwrap();

        assert_eq!(taken.len(), 3);
        assert_eq!(taken.scalar_at(0).unwrap(), 4.into());
        assert_eq!(taken.scalar_at(1).unwrap(), 2.into());
        assert_eq!(taken.scalar_at(2).unwrap(), 5.into());
    }

    #[test]
    fn ree_take_nullable() {
        let taken = take(
            ree_array().as_ref(),
            PrimitiveArray::from_option_iter([Some(1), None]).as_ref(),
        )
        .unwrap();

        assert_eq!(
            taken.scalar_at(0).unwrap(),
            Scalar::new(
                DType::Primitive(PType::I32, Nullability::Nullable),
                ScalarValue::from(1i32)
            )
        );
        assert_eq!(
            taken.scalar_at(1).unwrap(),
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
        );
    }

    #[rstest]
    #[case(ree_array())]
    #[case(RunEndArray::encode(
        PrimitiveArray::from_iter([1u8, 1, 2, 2, 2, 3, 3, 3, 3, 4]).into_array(),
    ).unwrap())]
    #[case(RunEndArray::encode(
        PrimitiveArray::from_option_iter([
            Some(10),
            Some(10),
            None,
            None,
            Some(20),
            Some(20),
            Some(20),
        ])
        .into_array(),
    ).unwrap())]
    #[case(RunEndArray::encode(PrimitiveArray::from_iter([42i32, 42, 42, 42, 42]).into_array())
        .unwrap())]
    #[case(RunEndArray::encode(
        PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10]).into_array(),
    ).unwrap())]
    #[case({
        let mut values = Vec::new();
        for i in 0..20 {
            for _ in 0..=i {
                values.push(i);
            }
        }
        RunEndArray::encode(PrimitiveArray::from_iter(values).into_array()).unwrap()
    })]
    fn test_take_runend_conformance(#[case] array: RunEndArray) {
        test_take_conformance(array.as_ref());
    }

    #[rstest]
    #[case(ree_array().slice(3, 6).unwrap())]
    #[case({
        let array = RunEndArray::encode(
            PrimitiveArray::from_iter([1i32, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]).into_array(),
        )
        .unwrap();
        array.slice(2, 8).unwrap()
    })]
    fn test_take_sliced_runend_conformance(#[case] sliced: ArrayRef) {
        test_take_conformance(sliced.as_ref());
    }
}