vortex_runend/compute/
take.rs

1use num_traits::{AsPrimitive, NumCast};
2use vortex_array::compute::{TakeKernel, TakeKernelAdapter, take};
3use vortex_array::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
4use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
5use vortex_buffer::Buffer;
6use vortex_dtype::match_each_integer_ptype;
7use vortex_error::{VortexResult, vortex_bail};
8
9use crate::{RunEndArray, RunEndVTable};
10
11impl TakeKernel for RunEndVTable {
12    fn take(&self, array: &RunEndArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
13        let primitive_indices = indices.to_primitive()?;
14
15        let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |$P| {
16            primitive_indices
17                .as_slice::<$P>()
18                .iter()
19                .copied()
20                .map(|idx| {
21                    let usize_idx = idx as usize;
22                    if usize_idx >= array.len() {
23                        vortex_bail!(OutOfBounds: usize_idx, 0, array.len());
24                    }
25                    Ok(usize_idx)
26                })
27                .collect::<VortexResult<Vec<_>>>()?
28        });
29
30        take_indices_unchecked(array, &checked_indices)
31    }
32}
33
34register_kernel!(TakeKernelAdapter(RunEndVTable).lift());
35
36/// Perform a take operation on a RunEndArray by binary searching for each of the indices.
37pub fn take_indices_unchecked<T: AsPrimitive<usize>>(
38    array: &RunEndArray,
39    indices: &[T],
40) -> VortexResult<ArrayRef> {
41    let ends = array.ends().to_primitive()?;
42    let ends_len = ends.len();
43
44    let physical_indices = match_each_integer_ptype!(ends.ptype(), |$I| {
45        let end_slices = ends.as_slice::<$I>();
46        indices
47            .iter()
48            .map(|idx| idx.as_() + array.offset())
49            .map(|idx| {
50                match <$I as NumCast>::from(idx) {
51                    Some(idx) => end_slices.search_sorted(&idx, SearchSortedSide::Right),
52                    None => {
53                        // The idx is too large for $I, therefore it's out of bounds.
54                        SearchResult::NotFound(ends_len)
55                    }
56                }
57            })
58            .map(|result| result.to_ends_index(ends_len) as u64)
59            .collect::<Buffer<u64>>()
60            .into_array()
61    });
62
63    take(array.values(), &physical_indices)
64}
65
#[cfg(test)]
mod test {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::take;
    use vortex_array::{Array, IntoArray, ToCanonical};

    use crate::RunEndArray;

    /// Builds the shared fixture: twelve `i32` values made of four runs
    /// (`1 x3`, `4 x3`, `2 x2`, `5 x4`), run-end encoded.
    fn ree_array() -> RunEndArray {
        let plain = PrimitiveArray::from_iter([1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5]);
        RunEndArray::encode(plain.into_array()).unwrap()
    }

    /// Taking unordered indices yields the value of the run containing each index.
    #[test]
    fn ree_take() {
        let encoded = ree_array();
        let indices = PrimitiveArray::from_iter([9, 8, 1, 3]);

        let taken = take(encoded.as_ref(), indices.as_ref()).unwrap();

        let values = taken.to_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>(), &[5, 5, 1, 4]);
    }

    /// The last valid logical index resolves to the final run.
    #[test]
    fn ree_take_end() {
        let encoded = ree_array();
        let indices = PrimitiveArray::from_iter([11]);

        let taken = take(encoded.as_ref(), indices.as_ref()).unwrap();

        let values = taken.to_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>(), &[5]);
    }

    /// An index equal to the array length must not produce a successful take.
    #[test]
    #[should_panic]
    fn ree_take_out_of_bounds() {
        let encoded = ree_array();
        let indices = PrimitiveArray::from_iter([12]);

        take(encoded.as_ref(), indices.as_ref()).unwrap();
    }

    /// Indices into a sliced array are interpreted relative to the slice start.
    #[test]
    fn sliced_take() {
        let sliced = ree_array().slice(4, 9).unwrap();
        let indices = PrimitiveArray::from_iter([1, 3, 4]);

        let taken = take(sliced.as_ref(), indices.as_ref()).unwrap();

        assert_eq!(taken.len(), 3);
        assert_eq!(taken.scalar_at(0).unwrap(), 4.into());
        assert_eq!(taken.scalar_at(1).unwrap(), 2.into());
        assert_eq!(taken.scalar_at(2).unwrap(), 5.into());
    }
}