vortex_runend/compute/
take.rs1use num_traits::{AsPrimitive, NumCast};
2use vortex_array::arrays::PrimitiveArray;
3use vortex_array::compute::{TakeKernel, TakeKernelAdapter, take};
4use vortex_array::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
5use vortex_array::validity::Validity;
6use vortex_array::vtable::ValidityHelper;
7use vortex_array::{Array, ArrayRef, ToCanonical, register_kernel};
8use vortex_buffer::Buffer;
9use vortex_dtype::match_each_integer_ptype;
10use vortex_error::{VortexResult, vortex_bail};
11
12use crate::{RunEndArray, RunEndVTable};
13
14impl TakeKernel for RunEndVTable {
15 #[allow(clippy::cast_possible_truncation)]
16 fn take(&self, array: &RunEndArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
17 let primitive_indices = indices.to_primitive()?;
18
19 let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| {
20 primitive_indices
21 .as_slice::<P>()
22 .iter()
23 .copied()
24 .map(|idx| {
25 let usize_idx = idx as usize;
26 if usize_idx >= array.len() {
27 vortex_bail!(OutOfBounds: usize_idx, 0, array.len());
28 }
29 Ok(usize_idx)
30 })
31 .collect::<VortexResult<Vec<_>>>()?
32 });
33
34 take_indices_unchecked(array, &checked_indices, primitive_indices.validity())
35 }
36}
37
38register_kernel!(TakeKernelAdapter(RunEndVTable).lift());
39
40pub fn take_indices_unchecked<T: AsPrimitive<usize>>(
42 array: &RunEndArray,
43 indices: &[T],
44 validity: &Validity,
45) -> VortexResult<ArrayRef> {
46 let ends = array.ends().to_primitive()?;
47 let ends_len = ends.len();
48
49 let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| {
51 let end_slices = ends.as_slice::<I>();
52 let buffer = indices
53 .iter()
54 .map(|idx| idx.as_() + array.offset())
55 .map(|idx| {
56 match <I as NumCast>::from(idx) {
57 Some(idx) => end_slices.search_sorted(&idx, SearchSortedSide::Right),
58 None => {
59 SearchResult::NotFound(ends_len)
61 }
62 }
63 })
64 .map(|result| result.to_ends_index(ends_len) as u64)
65 .collect::<Buffer<u64>>();
66
67 PrimitiveArray::new(buffer, validity.clone())
68 });
69
70 take(array.values(), physical_indices.as_ref())
71}
72
73#[cfg(test)]
74mod test {
75 use vortex_array::arrays::PrimitiveArray;
76 use vortex_array::compute::take;
77 use vortex_array::{Array, IntoArray, ToCanonical};
78 use vortex_dtype::{DType, Nullability, PType};
79 use vortex_scalar::{Scalar, ScalarValue};
80
81 use crate::RunEndArray;
82
83 fn ree_array() -> RunEndArray {
84 RunEndArray::encode(
85 PrimitiveArray::from_iter([1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5]).into_array(),
86 )
87 .unwrap()
88 }
89
90 #[test]
91 fn ree_take() {
92 let taken = take(
93 ree_array().as_ref(),
94 PrimitiveArray::from_iter([9, 8, 1, 3]).as_ref(),
95 )
96 .unwrap();
97 assert_eq!(
98 taken.to_primitive().unwrap().as_slice::<i32>(),
99 &[5, 5, 1, 4]
100 );
101 }
102
103 #[test]
104 fn ree_take_end() {
105 let taken = take(
106 ree_array().as_ref(),
107 PrimitiveArray::from_iter([11]).as_ref(),
108 )
109 .unwrap();
110 assert_eq!(taken.to_primitive().unwrap().as_slice::<i32>(), &[5]);
111 }
112
113 #[test]
114 #[should_panic]
115 fn ree_take_out_of_bounds() {
116 take(
117 ree_array().as_ref(),
118 PrimitiveArray::from_iter([12]).as_ref(),
119 )
120 .unwrap();
121 }
122
123 #[test]
124 fn sliced_take() {
125 let sliced = ree_array().slice(4, 9).unwrap();
126 let taken = take(
127 sliced.as_ref(),
128 PrimitiveArray::from_iter([1, 3, 4]).as_ref(),
129 )
130 .unwrap();
131
132 assert_eq!(taken.len(), 3);
133 assert_eq!(taken.scalar_at(0).unwrap(), 4.into());
134 assert_eq!(taken.scalar_at(1).unwrap(), 2.into());
135 assert_eq!(taken.scalar_at(2).unwrap(), 5.into());
136 }
137
138 #[test]
139 fn ree_take_nullable() {
140 let taken = take(
141 ree_array().as_ref(),
142 PrimitiveArray::from_option_iter([Some(1), None]).as_ref(),
143 )
144 .unwrap();
145
146 assert_eq!(
147 taken.scalar_at(0).unwrap(),
148 Scalar::new(
149 DType::Primitive(PType::I32, Nullability::Nullable),
150 ScalarValue::from(1i32)
151 )
152 );
153 assert_eq!(
154 taken.scalar_at(1).unwrap(),
155 Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
156 );
157 }
158}