Skip to main content

vortex_runend/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use num_traits::NumCast;
6use vortex_array::Array;
7use vortex_array::ArrayRef;
8use vortex_array::ExecutionCtx;
9use vortex_array::ToCanonical;
10use vortex_array::arrays::PrimitiveArray;
11use vortex_array::arrays::TakeExecute;
12use vortex_array::match_each_integer_ptype;
13use vortex_array::search_sorted::SearchResult;
14use vortex_array::search_sorted::SearchSorted;
15use vortex_array::search_sorted::SearchSortedSide;
16use vortex_array::validity::Validity;
17use vortex_array::vtable::ValidityHelper;
18use vortex_buffer::Buffer;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21
22use crate::RunEndArray;
23use crate::RunEndVTable;
24
25impl TakeExecute for RunEndVTable {
26    #[expect(
27        clippy::cast_possible_truncation,
28        reason = "index cast to usize inside macro"
29    )]
30    fn take(
31        array: &RunEndArray,
32        indices: &ArrayRef,
33        _ctx: &mut ExecutionCtx,
34    ) -> VortexResult<Option<ArrayRef>> {
35        let primitive_indices = indices.to_primitive();
36
37        let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| {
38            primitive_indices
39                .as_slice::<P>()
40                .iter()
41                .copied()
42                .map(|idx| {
43                    let usize_idx = idx as usize;
44                    if usize_idx >= array.len() {
45                        vortex_bail!(OutOfBounds: usize_idx, 0, array.len());
46                    }
47                    Ok(usize_idx)
48                })
49                .collect::<VortexResult<Vec<_>>>()?
50        });
51
52        take_indices_unchecked(array, &checked_indices, primitive_indices.validity()).map(Some)
53    }
54}
55
56/// Perform a take operation on a RunEndArray by binary searching for each of the indices.
57pub fn take_indices_unchecked<T: AsPrimitive<usize>>(
58    array: &RunEndArray,
59    indices: &[T],
60    validity: &Validity,
61) -> VortexResult<ArrayRef> {
62    let ends = array.ends().to_primitive();
63    let ends_len = ends.len();
64
65    // TODO(joe): use the validity mask to skip search sorted.
66    let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| {
67        let end_slices = ends.as_slice::<I>();
68        let physical_indices_vec: Vec<u64> = indices
69            .iter()
70            .map(|idx| idx.as_() + array.offset())
71            .map(|idx| {
72                match <I as NumCast>::from(idx) {
73                    Some(idx) => end_slices.search_sorted(&idx, SearchSortedSide::Right),
74                    None => {
75                        // The idx is too large for I, therefore it's out of bounds.
76                        Ok(SearchResult::NotFound(ends_len))
77                    }
78                }
79            })
80            .map(|result| result.map(|r| r.to_ends_index(ends_len) as u64))
81            .collect::<VortexResult<Vec<_>>>()?;
82        let buffer = Buffer::from(physical_indices_vec);
83
84        PrimitiveArray::new(buffer, validity.clone())
85    });
86
87    array.values().take(physical_indices.to_array())
88}
89
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::ArrayRef;
    use vortex_array::Canonical;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_buffer::buffer;

    use crate::RunEndArray;

    /// Twelve logical values in four runs: `1` x3, `4` x3, `2` x2, `5` x4.
    fn ree_array() -> RunEndArray {
        RunEndArray::encode(buffer![1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5].into_array()).unwrap()
    }

    #[test]
    fn ree_take() {
        let taken = ree_array().take(buffer![9, 8, 1, 3].into_array()).unwrap();
        let expected = PrimitiveArray::from_iter(vec![5i32, 5, 1, 4]).into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn ree_take_end() {
        let taken = ree_array().take(buffer![11].into_array()).unwrap();
        let expected = PrimitiveArray::from_iter(vec![5i32]).into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    #[should_panic]
    fn ree_take_out_of_bounds() {
        let _array = ree_array()
            .take(buffer![12].into_array())
            .unwrap()
            .execute::<Canonical>(&mut LEGACY_SESSION.create_execution_ctx())
            .unwrap();
    }

    /// A negative index must be rejected rather than wrapped into a valid position.
    #[test]
    #[should_panic]
    fn ree_take_negative_index() {
        let _array = ree_array()
            .take(buffer![-1i32].into_array())
            .unwrap()
            .execute::<Canonical>(&mut LEGACY_SESSION.create_execution_ctx())
            .unwrap();
    }

    #[test]
    fn sliced_take() {
        let sliced = ree_array().slice(4..9).unwrap();
        let taken = sliced.take(buffer![1, 3, 4].into_array()).unwrap();

        let expected = PrimitiveArray::from_iter(vec![4i32, 2, 5]).into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn ree_take_nullable() {
        let taken = ree_array()
            .take(PrimitiveArray::from_option_iter([Some(1), None]).to_array())
            .unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1i32), None]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[rstest]
    #[case(ree_array())]
    #[case(RunEndArray::encode(
        buffer![1u8, 1, 2, 2, 2, 3, 3, 3, 3, 4].into_array(),
    ).unwrap())]
    #[case(RunEndArray::encode(
        PrimitiveArray::from_option_iter([
            Some(10),
            Some(10),
            None,
            None,
            Some(20),
            Some(20),
            Some(20),
        ])
        .into_array(),
    ).unwrap())]
    #[case(RunEndArray::encode(buffer![42i32, 42, 42, 42, 42].into_array())
        .unwrap())]
    #[case(RunEndArray::encode(
        buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_array(),
    ).unwrap())]
    #[case({
        let mut values = Vec::new();
        for i in 0..20 {
            for _ in 0..=i {
                values.push(i);
            }
        }
        RunEndArray::encode(PrimitiveArray::from_iter(values).into_array()).unwrap()
    })]
    fn test_take_runend_conformance(#[case] array: RunEndArray) {
        test_take_conformance(&array.to_array());
    }

    #[rstest]
    #[case(ree_array().slice(3..6).unwrap())]
    #[case({
        let array = RunEndArray::encode(
            buffer![1i32, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3].into_array(),
        )
        .unwrap();
        array.slice(2..8).unwrap()
    })]
    fn test_take_sliced_runend_conformance(#[case] sliced: ArrayRef) {
        test_take_conformance(&sliced);
    }
}
200}