1use num_traits::{AsPrimitive, NumCast};
5use vortex_array::arrays::PrimitiveArray;
6use vortex_array::compute::{TakeKernel, TakeKernelAdapter, take};
7use vortex_array::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
8use vortex_array::validity::Validity;
9use vortex_array::vtable::ValidityHelper;
10use vortex_array::{Array, ArrayRef, ToCanonical, register_kernel};
11use vortex_buffer::Buffer;
12use vortex_dtype::match_each_integer_ptype;
13use vortex_error::{VortexResult, vortex_bail};
14
15use crate::{RunEndArray, RunEndVTable};
16
17impl TakeKernel for RunEndVTable {
18 #[allow(clippy::cast_possible_truncation)]
19 fn take(&self, array: &RunEndArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
20 let primitive_indices = indices.to_primitive();
21
22 let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| {
23 primitive_indices
24 .as_slice::<P>()
25 .iter()
26 .copied()
27 .map(|idx| {
28 let usize_idx = idx as usize;
29 if usize_idx >= array.len() {
30 vortex_bail!(OutOfBounds: usize_idx, 0, array.len());
31 }
32 Ok(usize_idx)
33 })
34 .collect::<VortexResult<Vec<_>>>()?
35 });
36
37 take_indices_unchecked(array, &checked_indices, primitive_indices.validity())
38 }
39}
40
41register_kernel!(TakeKernelAdapter(RunEndVTable).lift());
42
43pub fn take_indices_unchecked<T: AsPrimitive<usize>>(
45 array: &RunEndArray,
46 indices: &[T],
47 validity: &Validity,
48) -> VortexResult<ArrayRef> {
49 let ends = array.ends().to_primitive();
50 let ends_len = ends.len();
51
52 let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| {
54 let end_slices = ends.as_slice::<I>();
55 let buffer = Buffer::from_trusted_len_iter(
56 indices
57 .iter()
58 .map(|idx| idx.as_() + array.offset())
59 .map(|idx| {
60 match <I as NumCast>::from(idx) {
61 Some(idx) => end_slices.search_sorted(&idx, SearchSortedSide::Right),
62 None => {
63 SearchResult::NotFound(ends_len)
65 }
66 }
67 })
68 .map(|result| result.to_ends_index(ends_len) as u64),
69 );
70
71 PrimitiveArray::new(buffer, validity.clone())
72 });
73
74 take(array.values(), physical_indices.as_ref())
75}
76
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::compute::take;
    use vortex_array::{Array, ArrayRef, IntoArray, assert_arrays_eq};
    use vortex_buffer::buffer;

    use crate::RunEndArray;

    /// Twelve logical elements in four runs: `1` x3, `4` x3, `2` x2, `5` x4.
    fn ree_array() -> RunEndArray {
        RunEndArray::encode(buffer![1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5].into_array()).unwrap()
    }

    #[test]
    fn ree_take() {
        let taken = take(
            ree_array().as_ref(),
            buffer![9, 8, 1, 3].into_array().as_ref(),
        )
        .unwrap();
        let expected = PrimitiveArray::from_iter(vec![5i32, 5, 1, 4]).into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn ree_take_end() {
        let taken = take(ree_array().as_ref(), buffer![11].into_array().as_ref()).unwrap();
        let expected = PrimitiveArray::from_iter(vec![5i32]).into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn ree_take_out_of_bounds() {
        // Index 12 is one past the last element and must be rejected with an error. Checking the
        // `Result` directly (rather than `#[should_panic]` + `unwrap`) keeps an unrelated panic
        // from making this test pass.
        let result = take(ree_array().as_ref(), buffer![12].into_array().as_ref());
        assert!(result.is_err());
    }

    #[test]
    fn sliced_take() {
        let sliced = ree_array().slice(4..9);
        let taken = take(sliced.as_ref(), buffer![1, 3, 4].into_array().as_ref()).unwrap();

        let expected = PrimitiveArray::from_iter(vec![4i32, 2, 5]).into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn ree_take_nullable() {
        let taken = take(
            ree_array().as_ref(),
            PrimitiveArray::from_option_iter([Some(1), None]).as_ref(),
        )
        .unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1i32), None]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[rstest]
    #[case(ree_array())]
    #[case(RunEndArray::encode(
        buffer![1u8, 1, 2, 2, 2, 3, 3, 3, 3, 4].into_array(),
    ).unwrap())]
    #[case(RunEndArray::encode(
        PrimitiveArray::from_option_iter([
            Some(10),
            Some(10),
            None,
            None,
            Some(20),
            Some(20),
            Some(20),
        ])
        .into_array(),
    ).unwrap())]
    #[case(RunEndArray::encode(buffer![42i32, 42, 42, 42, 42].into_array())
        .unwrap())]
    #[case(RunEndArray::encode(
        buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_array(),
    ).unwrap())]
    #[case({
        // Run `i` has length `i + 1`, giving 20 runs of increasing length.
        let mut values = Vec::new();
        for i in 0..20 {
            for _ in 0..=i {
                values.push(i);
            }
        }
        RunEndArray::encode(PrimitiveArray::from_iter(values).into_array()).unwrap()
    })]
    fn test_take_runend_conformance(#[case] array: RunEndArray) {
        test_take_conformance(array.as_ref());
    }

    #[rstest]
    #[case(ree_array().slice(3..6))]
    #[case({
        let array = RunEndArray::encode(
            buffer![1i32, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3].into_array(),
        )
        .unwrap();
        array.slice(2..8)
    })]
    fn test_take_sliced_runend_conformance(#[case] sliced: ArrayRef) {
        test_take_conformance(sliced.as_ref());
    }
}