1use num_traits::AsPrimitive;
5use num_traits::NumCast;
6use vortex_array::Array;
7use vortex_array::ArrayRef;
8use vortex_array::ToCanonical;
9use vortex_array::arrays::PrimitiveArray;
10use vortex_array::compute::TakeKernel;
11use vortex_array::compute::TakeKernelAdapter;
12use vortex_array::compute::take;
13use vortex_array::register_kernel;
14use vortex_array::search_sorted::SearchResult;
15use vortex_array::search_sorted::SearchSorted;
16use vortex_array::search_sorted::SearchSortedSide;
17use vortex_array::validity::Validity;
18use vortex_array::vtable::ValidityHelper;
19use vortex_buffer::Buffer;
20use vortex_dtype::match_each_integer_ptype;
21use vortex_error::VortexResult;
22use vortex_error::vortex_bail;
23
24use crate::RunEndArray;
25use crate::RunEndVTable;
26
27impl TakeKernel for RunEndVTable {
28 #[expect(
29 clippy::cast_possible_truncation,
30 reason = "index cast to usize inside macro"
31 )]
32 fn take(&self, array: &RunEndArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
33 let primitive_indices = indices.to_primitive();
34
35 let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| {
36 primitive_indices
37 .as_slice::<P>()
38 .iter()
39 .copied()
40 .map(|idx| {
41 let usize_idx = idx as usize;
42 if usize_idx >= array.len() {
43 vortex_bail!(OutOfBounds: usize_idx, 0, array.len());
44 }
45 Ok(usize_idx)
46 })
47 .collect::<VortexResult<Vec<_>>>()?
48 });
49
50 take_indices_unchecked(array, &checked_indices, primitive_indices.validity())
51 }
52}
53
54register_kernel!(TakeKernelAdapter(RunEndVTable).lift());
55
56pub fn take_indices_unchecked<T: AsPrimitive<usize>>(
58 array: &RunEndArray,
59 indices: &[T],
60 validity: &Validity,
61) -> VortexResult<ArrayRef> {
62 let ends = array.ends().to_primitive();
63 let ends_len = ends.len();
64
65 let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| {
67 let end_slices = ends.as_slice::<I>();
68 let buffer = Buffer::from_trusted_len_iter(
69 indices
70 .iter()
71 .map(|idx| idx.as_() + array.offset())
72 .map(|idx| {
73 match <I as NumCast>::from(idx) {
74 Some(idx) => end_slices.search_sorted(&idx, SearchSortedSide::Right),
75 None => {
76 SearchResult::NotFound(ends_len)
78 }
79 }
80 })
81 .map(|result| result.to_ends_index(ends_len) as u64),
82 );
83
84 PrimitiveArray::new(buffer, validity.clone())
85 });
86
87 take(array.values(), physical_indices.as_ref())
88}
89
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::compute::take;
    use vortex_buffer::buffer;

    use crate::RunEndArray;

    /// Twelve values forming four runs: `1 x3, 4 x3, 2 x2, 5 x4`.
    fn ree_array() -> RunEndArray {
        RunEndArray::encode(buffer![1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5].into_array()).unwrap()
    }

    #[test]
    fn ree_take() {
        let taken = take(
            ree_array().as_ref(),
            buffer![9, 8, 1, 3].into_array().as_ref(),
        )
        .unwrap();
        let expected = PrimitiveArray::from_iter(vec![5i32, 5, 1, 4]).into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn ree_take_end() {
        let taken = take(ree_array().as_ref(), buffer![11].into_array().as_ref()).unwrap();
        let expected = PrimitiveArray::from_iter(vec![5i32]).into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn ree_take_out_of_bounds() {
        // Index 12 is one past the end. Assert on the returned error directly instead of
        // relying on `#[should_panic]`, which would also accept unrelated panics.
        let result = take(ree_array().as_ref(), buffer![12].into_array().as_ref());
        assert!(result.is_err());
    }

    #[test]
    fn ree_take_negative_index() {
        // A negative signed index wraps to a huge `usize` and must still be rejected.
        let result = take(ree_array().as_ref(), buffer![-1i32].into_array().as_ref());
        assert!(result.is_err());
    }

    #[test]
    fn sliced_take() {
        let sliced = ree_array().slice(4..9);
        let taken = take(sliced.as_ref(), buffer![1, 3, 4].into_array().as_ref()).unwrap();

        let expected = PrimitiveArray::from_iter(vec![4i32, 2, 5]).into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn sliced_take_out_of_bounds() {
        // The slice has 5 elements, so index 5 is out of bounds even though the
        // underlying array has 12.
        let sliced = ree_array().slice(4..9);
        let result = take(sliced.as_ref(), buffer![5].into_array().as_ref());
        assert!(result.is_err());
    }

    #[test]
    fn ree_take_nullable() {
        let taken = take(
            ree_array().as_ref(),
            PrimitiveArray::from_option_iter([Some(1), None]).as_ref(),
        )
        .unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1i32), None]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[rstest]
    #[case(ree_array())]
    #[case(RunEndArray::encode(
        buffer![1u8, 1, 2, 2, 2, 3, 3, 3, 3, 4].into_array(),
    ).unwrap())]
    #[case(RunEndArray::encode(
        PrimitiveArray::from_option_iter([
            Some(10),
            Some(10),
            None,
            None,
            Some(20),
            Some(20),
            Some(20),
        ])
        .into_array(),
    ).unwrap())]
    #[case(RunEndArray::encode(buffer![42i32, 42, 42, 42, 42].into_array())
        .unwrap())]
    #[case(RunEndArray::encode(
        buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_array(),
    ).unwrap())]
    #[case({
        let mut values = Vec::new();
        for i in 0..20 {
            for _ in 0..=i {
                values.push(i);
            }
        }
        RunEndArray::encode(PrimitiveArray::from_iter(values).into_array()).unwrap()
    })]
    fn test_take_runend_conformance(#[case] array: RunEndArray) {
        test_take_conformance(array.as_ref());
    }

    #[rstest]
    #[case(ree_array().slice(3..6))]
    #[case({
        let array = RunEndArray::encode(
            buffer![1i32, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3].into_array(),
        )
        .unwrap();
        array.slice(2..8)
    })]
    fn test_take_sliced_runend_conformance(#[case] sliced: ArrayRef) {
        test_take_conformance(sliced.as_ref());
    }
}