vortex_runend/compute/
take.rs

1use num_traits::{AsPrimitive, NumCast};
2use vortex_array::arrays::PrimitiveArray;
3use vortex_array::compute::{TakeKernel, TakeKernelAdapter, take};
4use vortex_array::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
5use vortex_array::validity::Validity;
6use vortex_array::vtable::ValidityHelper;
7use vortex_array::{Array, ArrayRef, ToCanonical, register_kernel};
8use vortex_buffer::Buffer;
9use vortex_dtype::match_each_integer_ptype;
10use vortex_error::{VortexResult, vortex_bail};
11
12use crate::{RunEndArray, RunEndVTable};
13
14impl TakeKernel for RunEndVTable {
15    #[allow(clippy::cast_possible_truncation)]
16    fn take(&self, array: &RunEndArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
17        let primitive_indices = indices.to_primitive()?;
18
19        let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| {
20            primitive_indices
21                .as_slice::<P>()
22                .iter()
23                .copied()
24                .map(|idx| {
25                    let usize_idx = idx as usize;
26                    if usize_idx >= array.len() {
27                        vortex_bail!(OutOfBounds: usize_idx, 0, array.len());
28                    }
29                    Ok(usize_idx)
30                })
31                .collect::<VortexResult<Vec<_>>>()?
32        });
33
34        take_indices_unchecked(array, &checked_indices, primitive_indices.validity())
35    }
36}
37
38register_kernel!(TakeKernelAdapter(RunEndVTable).lift());
39
40/// Perform a take operation on a RunEndArray by binary searching for each of the indices.
41pub fn take_indices_unchecked<T: AsPrimitive<usize>>(
42    array: &RunEndArray,
43    indices: &[T],
44    validity: &Validity,
45) -> VortexResult<ArrayRef> {
46    let ends = array.ends().to_primitive()?;
47    let ends_len = ends.len();
48
49    // TODO(joe): use the validity mask to skip search sorted.
50    let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| {
51        let end_slices = ends.as_slice::<I>();
52        let buffer = indices
53            .iter()
54            .map(|idx| idx.as_() + array.offset())
55            .map(|idx| {
56                match <I as NumCast>::from(idx) {
57                    Some(idx) => end_slices.search_sorted(&idx, SearchSortedSide::Right),
58                    None => {
59                        // The idx is too large for I, therefore it's out of bounds.
60                        SearchResult::NotFound(ends_len)
61                    }
62                }
63            })
64            .map(|result| result.to_ends_index(ends_len) as u64)
65            .collect::<Buffer<u64>>();
66
67        PrimitiveArray::new(buffer, validity.clone())
68    });
69
70    take(array.values(), physical_indices.as_ref())
71}
72
#[cfg(test)]
mod test {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::take;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::{Scalar, ScalarValue};

    use crate::RunEndArray;

    /// Fixture: run-end encoding of `[1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5]`.
    fn ree_array() -> RunEndArray {
        let decoded = PrimitiveArray::from_iter([1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5]);
        RunEndArray::encode(decoded.into_array()).unwrap()
    }

    /// Takes the non-nullable `i32` `indices` from `source`, panicking if the take fails.
    fn take_i32(source: &dyn Array, indices: &[i32]) -> ArrayRef {
        let indices = PrimitiveArray::from_iter(indices.iter().copied());
        take(source, indices.as_ref()).unwrap()
    }

    #[test]
    fn ree_take() {
        let taken = take_i32(ree_array().as_ref(), &[9, 8, 1, 3]);
        let taken = taken.to_primitive().unwrap();
        assert_eq!(taken.as_slice::<i32>(), &[5, 5, 1, 4]);
    }

    #[test]
    fn ree_take_end() {
        // Index 11 is the last logical position of the fixture.
        let taken = take_i32(ree_array().as_ref(), &[11]);
        let taken = taken.to_primitive().unwrap();
        assert_eq!(taken.as_slice::<i32>(), &[5]);
    }

    #[test]
    #[should_panic]
    fn ree_take_out_of_bounds() {
        // Index 12 equals the fixture length, so the take errors and the unwrap panics.
        take_i32(ree_array().as_ref(), &[12]);
    }

    #[test]
    fn sliced_take() {
        let sliced = ree_array().slice(4, 9).unwrap();
        let taken = take_i32(sliced.as_ref(), &[1, 3, 4]);

        assert_eq!(taken.len(), 3);
        for (position, expected) in [4i32, 2, 5].into_iter().enumerate() {
            assert_eq!(taken.scalar_at(position).unwrap(), expected.into());
        }
    }

    #[test]
    fn ree_take_nullable() {
        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
        let taken = take(ree_array().as_ref(), indices.as_ref()).unwrap();

        let nullable_i32 = DType::Primitive(PType::I32, Nullability::Nullable);

        let valid_slot = taken.scalar_at(0).unwrap();
        assert_eq!(
            valid_slot,
            Scalar::new(nullable_i32.clone(), ScalarValue::from(1i32))
        );

        let null_slot = taken.scalar_at(1).unwrap();
        assert_eq!(null_slot, Scalar::null(nullable_i32));
    }
}