vortex_runend/compute/take.rs

1use num_traits::{AsPrimitive, NumCast};
2use vortex_array::compute::{TakeKernel, TakeKernelAdapter, take};
3use vortex_array::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
4use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
5use vortex_buffer::Buffer;
6use vortex_dtype::match_each_integer_ptype;
7use vortex_error::{VortexResult, vortex_bail};
8
9use crate::{RunEndArray, RunEndVTable};
10
11impl TakeKernel for RunEndVTable {
12    #[allow(clippy::cast_possible_truncation)]
13    fn take(&self, array: &RunEndArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
14        let primitive_indices = indices.to_primitive()?;
15
16        let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| {
17            primitive_indices
18                .as_slice::<P>()
19                .iter()
20                .copied()
21                .map(|idx| {
22                    let usize_idx = idx as usize;
23                    if usize_idx >= array.len() {
24                        vortex_bail!(OutOfBounds: usize_idx, 0, array.len());
25                    }
26                    Ok(usize_idx)
27                })
28                .collect::<VortexResult<Vec<_>>>()?
29        });
30
31        take_indices_unchecked(array, &checked_indices)
32    }
33}
34
35register_kernel!(TakeKernelAdapter(RunEndVTable).lift());
36
37/// Perform a take operation on a RunEndArray by binary searching for each of the indices.
38pub fn take_indices_unchecked<T: AsPrimitive<usize>>(
39    array: &RunEndArray,
40    indices: &[T],
41) -> VortexResult<ArrayRef> {
42    let ends = array.ends().to_primitive()?;
43    let ends_len = ends.len();
44
45    let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| {
46        let end_slices = ends.as_slice::<I>();
47        indices
48            .iter()
49            .map(|idx| idx.as_() + array.offset())
50            .map(|idx| {
51                match <I as NumCast>::from(idx) {
52                    Some(idx) => end_slices.search_sorted(&idx, SearchSortedSide::Right),
53                    None => {
54                        // The idx is too large for I, therefore it's out of bounds.
55                        SearchResult::NotFound(ends_len)
56                    }
57                }
58            })
59            .map(|result| result.to_ends_index(ends_len) as u64)
60            .collect::<Buffer<u64>>()
61            .into_array()
62    });
63
64    take(array.values(), &physical_indices)
65}
66
#[cfg(test)]
mod test {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::take;
    use vortex_array::{Array, IntoArray, ToCanonical};

    use crate::RunEndArray;

    /// Run-end encodes the 12-element array `[1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5]`.
    fn ree_array() -> RunEndArray {
        RunEndArray::encode(
            PrimitiveArray::from_iter([1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5]).into_array(),
        )
        .unwrap()
    }

    #[test]
    fn ree_take() {
        let taken = take(
            ree_array().as_ref(),
            PrimitiveArray::from_iter([9, 8, 1, 3]).as_ref(),
        )
        .unwrap();
        assert_eq!(
            taken.to_primitive().unwrap().as_slice::<i32>(),
            &[5, 5, 1, 4]
        );
    }

    #[test]
    fn ree_take_end() {
        // The last valid logical position must resolve to the final run.
        let taken = take(
            ree_array().as_ref(),
            PrimitiveArray::from_iter([11]).as_ref(),
        )
        .unwrap();
        assert_eq!(taken.to_primitive().unwrap().as_slice::<i32>(), &[5]);
    }

    #[test]
    fn ree_take_out_of_bounds() {
        // Index 12 is one past the last valid position, so take must return an error.
        // Asserting on the `Result` (rather than `#[should_panic]`) keeps an unrelated panic
        // from masquerading as a passing test.
        let result = take(
            ree_array().as_ref(),
            PrimitiveArray::from_iter([12]).as_ref(),
        );
        assert!(result.is_err());
    }

    #[test]
    fn ree_take_negative_index() {
        // Negative indices are never valid positions and must be rejected.
        let result = take(
            ree_array().as_ref(),
            PrimitiveArray::from_iter([-1i32]).as_ref(),
        );
        assert!(result.is_err());
    }

    #[test]
    fn sliced_take() {
        // The slice covers logical positions 4..9, i.e. [4, 4, 2, 2, 5].
        let sliced = ree_array().slice(4, 9).unwrap();
        let taken = take(
            sliced.as_ref(),
            PrimitiveArray::from_iter([1, 3, 4]).as_ref(),
        )
        .unwrap();

        assert_eq!(taken.len(), 3);
        assert_eq!(taken.scalar_at(0).unwrap(), 4.into());
        assert_eq!(taken.scalar_at(1).unwrap(), 2.into());
        assert_eq!(taken.scalar_at(2).unwrap(), 5.into());
    }
}