Skip to main content

vortex_runend/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use num_traits::NumCast;
6use vortex_array::ArrayRef;
7use vortex_array::DynArray;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::ToCanonical;
11use vortex_array::arrays::PrimitiveArray;
12use vortex_array::arrays::dict::TakeExecute;
13use vortex_array::match_each_integer_ptype;
14use vortex_array::search_sorted::SearchResult;
15use vortex_array::search_sorted::SearchSorted;
16use vortex_array::search_sorted::SearchSortedSide;
17use vortex_array::validity::Validity;
18use vortex_array::vtable::ValidityHelper;
19use vortex_buffer::Buffer;
20use vortex_error::VortexResult;
21use vortex_error::vortex_bail;
22
23use crate::RunEndArray;
24use crate::RunEndVTable;
25
26impl TakeExecute for RunEndVTable {
27    #[expect(
28        clippy::cast_possible_truncation,
29        reason = "index cast to usize inside macro"
30    )]
31    fn take(
32        array: &RunEndArray,
33        indices: &ArrayRef,
34        ctx: &mut ExecutionCtx,
35    ) -> VortexResult<Option<ArrayRef>> {
36        let primitive_indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
37
38        let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| {
39            primitive_indices
40                .as_slice::<P>()
41                .iter()
42                .copied()
43                .map(|idx| {
44                    let usize_idx = idx as usize;
45                    if usize_idx >= array.len() {
46                        vortex_bail!(OutOfBounds: usize_idx, 0, array.len());
47                    }
48                    Ok(usize_idx)
49                })
50                .collect::<VortexResult<Vec<_>>>()?
51        });
52
53        take_indices_unchecked(array, &checked_indices, primitive_indices.validity()).map(Some)
54    }
55}
56
57/// Perform a take operation on a RunEndArray by binary searching for each of the indices.
58pub fn take_indices_unchecked<T: AsPrimitive<usize>>(
59    array: &RunEndArray,
60    indices: &[T],
61    validity: &Validity,
62) -> VortexResult<ArrayRef> {
63    let ends = array.ends().to_primitive();
64    let ends_len = ends.len();
65
66    // TODO(joe): use the validity mask to skip search sorted.
67    let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| {
68        let end_slices = ends.as_slice::<I>();
69        let physical_indices_vec: Vec<u64> = indices
70            .iter()
71            .map(|idx| idx.as_() + array.offset())
72            .map(|idx| {
73                match <I as NumCast>::from(idx) {
74                    Some(idx) => end_slices.search_sorted(&idx, SearchSortedSide::Right),
75                    None => {
76                        // The idx is too large for I, therefore it's out of bounds.
77                        Ok(SearchResult::NotFound(ends_len))
78                    }
79                }
80            })
81            .map(|result| result.map(|r| r.to_ends_index(ends_len) as u64))
82            .collect::<VortexResult<Vec<_>>>()?;
83        let buffer = Buffer::from(physical_indices_vec);
84
85        PrimitiveArray::new(buffer, validity.clone())
86    });
87
88    array.values().take(physical_indices.into_array())
89}
90
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::Canonical;
    use vortex_array::DynArray;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_buffer::buffer;

    use crate::RunEndArray;

    /// Shared fixture: twelve logical values forming the runs 1x3, 4x3, 2x2 and 5x4.
    fn ree_array() -> RunEndArray {
        let values = buffer![1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5].into_array();
        RunEndArray::encode(values).unwrap()
    }

    // Unordered indices that land in several different runs.
    #[test]
    fn ree_take() {
        let indices = buffer![9, 8, 1, 3].into_array();
        let actual = ree_array().take(indices).unwrap();

        let expected = PrimitiveArray::from_iter(vec![5i32, 5, 1, 4]).into_array();
        assert_arrays_eq!(actual, expected);
    }

    // The last valid logical index resolves to the final run.
    #[test]
    fn ree_take_end() {
        let indices = buffer![11].into_array();
        let actual = ree_array().take(indices).unwrap();

        let expected = PrimitiveArray::from_iter(vec![5i32]).into_array();
        assert_arrays_eq!(actual, expected);
    }

    // Index 12 is one past the end of the twelve-element fixture.
    #[test]
    #[should_panic]
    fn ree_take_out_of_bounds() {
        let indices = buffer![12].into_array();
        let pending = ree_array().take(indices).unwrap();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let _array = pending.execute::<Canonical>(&mut ctx).unwrap();
    }

    // Indices into a slice are relative to the slice start, not the original array.
    #[test]
    fn sliced_take() {
        let window = ree_array().slice(4..9).unwrap();
        let indices = buffer![1, 3, 4].into_array();
        let actual = window.take(indices).unwrap();

        let expected = PrimitiveArray::from_iter(vec![4i32, 2, 5]).into_array();
        assert_arrays_eq!(actual, expected);
    }

    // A null index yields a null output element.
    #[test]
    fn ree_take_nullable() {
        let indices = PrimitiveArray::from_option_iter([Some(1), None]).into_array();
        let actual = ree_array().take(indices).unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1i32), None]).into_array();
        assert_arrays_eq!(actual, expected);
    }

    // Shared take conformance suite over a variety of unsliced run-end arrays.
    #[rstest]
    #[case(ree_array())]
    #[case(RunEndArray::encode(
        buffer![1u8, 1, 2, 2, 2, 3, 3, 3, 3, 4].into_array(),
    ).unwrap())]
    #[case(RunEndArray::encode(
        PrimitiveArray::from_option_iter([
            Some(10),
            Some(10),
            None,
            None,
            Some(20),
            Some(20),
            Some(20),
        ])
        .into_array(),
    ).unwrap())]
    #[case(RunEndArray::encode(buffer![42i32, 42, 42, 42, 42].into_array())
        .unwrap())]
    #[case(RunEndArray::encode(
        buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_array(),
    ).unwrap())]
    #[case({
        // Value `i` is repeated `i + 1` times, giving runs of increasing length.
        let values: Vec<_> = (0..20)
            .flat_map(|i| (0..=i).map(move |_| i))
            .collect();
        RunEndArray::encode(PrimitiveArray::from_iter(values).into_array()).unwrap()
    })]
    fn test_take_runend_conformance(#[case] array: RunEndArray) {
        let array = array.into_array();
        test_take_conformance(&array);
    }

    // Same conformance suite, applied to sliced run-end arrays.
    #[rstest]
    #[case(ree_array().slice(3..6).unwrap())]
    #[case({
        let values = buffer![1i32, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3].into_array();
        let encoded = RunEndArray::encode(values).unwrap();
        encoded.slice(2..8).unwrap()
    })]
    fn test_take_sliced_runend_conformance(#[case] sliced: ArrayRef) {
        test_take_conformance(&sliced);
    }
}