Skip to main content

vortex_runend/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use num_traits::NumCast;
6use vortex_array::ArrayRef;
7use vortex_array::ArrayView;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::arrays::PrimitiveArray;
11use vortex_array::arrays::dict::TakeExecute;
12use vortex_array::match_each_integer_ptype;
13use vortex_array::search_sorted::SearchResult;
14use vortex_array::search_sorted::SearchSorted;
15use vortex_array::search_sorted::SearchSortedSide;
16use vortex_array::validity::Validity;
17use vortex_buffer::Buffer;
18use vortex_error::VortexResult;
19use vortex_error::vortex_bail;
20
21use crate::RunEnd;
22use crate::array::RunEndArrayExt;
23
24impl TakeExecute for RunEnd {
25    #[expect(
26        clippy::cast_possible_truncation,
27        reason = "index cast to usize inside macro"
28    )]
29    fn take(
30        array: ArrayView<'_, Self>,
31        indices: &ArrayRef,
32        ctx: &mut ExecutionCtx,
33    ) -> VortexResult<Option<ArrayRef>> {
34        let primitive_indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
35
36        let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| {
37            primitive_indices
38                .as_slice::<P>()
39                .iter()
40                .copied()
41                .map(|idx| {
42                    let usize_idx = idx as usize;
43                    if usize_idx >= array.len() {
44                        vortex_bail!(OutOfBounds: usize_idx, 0, array.len());
45                    }
46                    Ok(usize_idx)
47                })
48                .collect::<VortexResult<Vec<_>>>()?
49        });
50
51        let indices_validity = primitive_indices.validity()?;
52        take_indices_unchecked(array, &checked_indices, &indices_validity, ctx).map(Some)
53    }
54}
55
56/// Perform a take operation on a RunEndArray by binary searching for each of the indices.
57pub fn take_indices_unchecked<T: AsPrimitive<usize>>(
58    array: ArrayView<'_, RunEnd>,
59    indices: &[T],
60    validity: &Validity,
61    ctx: &mut ExecutionCtx,
62) -> VortexResult<ArrayRef> {
63    let ends = array.ends().clone().execute::<PrimitiveArray>(ctx)?;
64    let ends_len = ends.len();
65
66    // TODO(joe): use the validity mask to skip search sorted.
67    let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| {
68        let end_slices = ends.as_slice::<I>();
69        let physical_indices_vec: Vec<u64> = indices
70            .iter()
71            .map(|idx| idx.as_() + array.offset())
72            .map(|idx| {
73                match <I as NumCast>::from(idx) {
74                    Some(idx) => end_slices.search_sorted(&idx, SearchSortedSide::Right),
75                    None => {
76                        // The idx is too large for I, therefore it's out of bounds.
77                        Ok(SearchResult::NotFound(ends_len))
78                    }
79                }
80            })
81            .map(|result| result.map(|r| r.to_ends_index(ends_len) as u64))
82            .collect::<VortexResult<Vec<_>>>()?;
83        let buffer = Buffer::from(physical_indices_vec);
84
85        PrimitiveArray::new(buffer, validity.clone())
86    });
87
88    array.values().take(physical_indices.into_array())
89}
90
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::Canonical;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_buffer::buffer;

    use crate::RunEnd;
    use crate::RunEndArray;

    /// Run-end encodes `values` with a fresh execution context from the legacy session.
    fn encode(values: ArrayRef) -> RunEndArray {
        RunEnd::encode(values, &mut LEGACY_SESSION.create_execution_ctx()).unwrap()
    }

    /// Twelve elements in four runs: three 1s, three 4s, two 2s and four 5s.
    fn ree_array() -> RunEndArray {
        encode(buffer![1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5].into_array())
    }

    /// Out-of-order indices spanning several runs.
    #[test]
    fn ree_take() {
        let picks = buffer![9, 8, 1, 3].into_array();
        let actual = ree_array().take(picks).unwrap();
        let wanted = PrimitiveArray::from_iter(vec![5i32, 5, 1, 4]).into_array();
        assert_arrays_eq!(actual, wanted);
    }

    /// The very last logical position falls in the final run.
    #[test]
    fn ree_take_end() {
        let picks = buffer![11].into_array();
        let actual = ree_array().take(picks).unwrap();
        let wanted = PrimitiveArray::from_iter(vec![5i32]).into_array();
        assert_arrays_eq!(actual, wanted);
    }

    /// An index equal to the length must fail once the result is executed.
    #[test]
    #[should_panic]
    fn ree_take_out_of_bounds() {
        let lazy = ree_array().take(buffer![12].into_array()).unwrap();
        let _canonical = lazy
            .execute::<Canonical>(&mut LEGACY_SESSION.create_execution_ctx())
            .unwrap();
    }

    /// Indices into a slice are relative to the slice start, not the parent array.
    #[test]
    fn sliced_take() {
        let window = ree_array().slice(4..9).unwrap();
        let actual = window.take(buffer![1, 3, 4].into_array()).unwrap();
        let wanted = PrimitiveArray::from_iter(vec![4i32, 2, 5]).into_array();
        assert_arrays_eq!(actual, wanted);
    }

    /// A null index produces a null output slot.
    #[test]
    fn ree_take_nullable() {
        let picks = PrimitiveArray::from_option_iter([Some(1), None]).into_array();
        let actual = ree_array().take(picks).unwrap();
        let wanted = PrimitiveArray::from_option_iter([Some(1i32), None]).into_array();
        assert_arrays_eq!(actual, wanted);
    }

    #[rstest]
    #[case(ree_array())]
    #[case(encode(buffer![1u8, 1, 2, 2, 2, 3, 3, 3, 3, 4].into_array()))]
    #[case(encode(
        PrimitiveArray::from_option_iter([
            Some(10),
            Some(10),
            None,
            None,
            Some(20),
            Some(20),
            Some(20),
        ])
        .into_array(),
    ))]
    #[case(encode(buffer![42i32, 42, 42, 42, 42].into_array()))]
    #[case(encode(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_array()))]
    #[case({
        // Run number `run` repeats its own value `run + 1` times.
        let mut elements = Vec::new();
        for run in 0..20 {
            for _ in 0..=run {
                elements.push(run);
            }
        }
        encode(PrimitiveArray::from_iter(elements).into_array())
    })]
    fn test_take_runend_conformance(#[case] array: RunEndArray) {
        test_take_conformance(&array.into_array());
    }

    #[rstest]
    #[case(ree_array().slice(3..6).unwrap())]
    #[case(
        encode(buffer![1i32, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3].into_array())
            .slice(2..8)
            .unwrap()
    )]
    fn test_take_sliced_runend_conformance(#[case] sliced: ArrayRef) {
        test_take_conformance(&sliced);
    }
}