vortex_runend/compute/take.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use num_traits::{AsPrimitive, NumCast};
5use vortex_array::arrays::PrimitiveArray;
6use vortex_array::compute::{TakeKernel, TakeKernelAdapter, take};
7use vortex_array::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
8use vortex_array::validity::Validity;
9use vortex_array::vtable::ValidityHelper;
10use vortex_array::{Array, ArrayRef, ToCanonical, register_kernel};
11use vortex_buffer::Buffer;
12use vortex_dtype::match_each_integer_ptype;
13use vortex_error::{VortexResult, vortex_bail};
14
15use crate::{RunEndArray, RunEndVTable};
16
17impl TakeKernel for RunEndVTable {
18    #[allow(clippy::cast_possible_truncation)]
19    fn take(&self, array: &RunEndArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
20        let primitive_indices = indices.to_primitive()?;
21
22        let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| {
23            primitive_indices
24                .as_slice::<P>()
25                .iter()
26                .copied()
27                .map(|idx| {
28                    let usize_idx = idx as usize;
29                    if usize_idx >= array.len() {
30                        vortex_bail!(OutOfBounds: usize_idx, 0, array.len());
31                    }
32                    Ok(usize_idx)
33                })
34                .collect::<VortexResult<Vec<_>>>()?
35        });
36
37        take_indices_unchecked(array, &checked_indices, primitive_indices.validity())
38    }
39}
40
41register_kernel!(TakeKernelAdapter(RunEndVTable).lift());
42
43/// Perform a take operation on a RunEndArray by binary searching for each of the indices.
44pub fn take_indices_unchecked<T: AsPrimitive<usize>>(
45    array: &RunEndArray,
46    indices: &[T],
47    validity: &Validity,
48) -> VortexResult<ArrayRef> {
49    let ends = array.ends().to_primitive()?;
50    let ends_len = ends.len();
51
52    // TODO(joe): use the validity mask to skip search sorted.
53    let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| {
54        let end_slices = ends.as_slice::<I>();
55        let buffer = Buffer::from_trusted_len_iter(
56            indices
57                .iter()
58                .map(|idx| idx.as_() + array.offset())
59                .map(|idx| {
60                    match <I as NumCast>::from(idx) {
61                        Some(idx) => end_slices.search_sorted(&idx, SearchSortedSide::Right),
62                        None => {
63                            // The idx is too large for I, therefore it's out of bounds.
64                            SearchResult::NotFound(ends_len)
65                        }
66                    }
67                })
68                .map(|result| result.to_ends_index(ends_len) as u64),
69        );
70
71        PrimitiveArray::new(buffer, validity.clone())
72    });
73
74    take(array.values(), physical_indices.as_ref())
75}
76
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::compute::take;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::{Scalar, ScalarValue};

    use crate::RunEndArray;

    /// Run-end encodes `[1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5]` (twelve logical elements).
    fn ree_array() -> RunEndArray {
        RunEndArray::encode(
            PrimitiveArray::from_iter([1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5]).into_array(),
        )
        .unwrap()
    }

    #[test]
    fn ree_take() {
        let taken = take(
            ree_array().as_ref(),
            PrimitiveArray::from_iter([9, 8, 1, 3]).as_ref(),
        )
        .unwrap();
        assert_eq!(
            taken.to_primitive().unwrap().as_slice::<i32>(),
            &[5, 5, 1, 4]
        );
    }

    #[test]
    fn ree_take_end() {
        let taken = take(
            ree_array().as_ref(),
            PrimitiveArray::from_iter([11]).as_ref(),
        )
        .unwrap();
        assert_eq!(taken.to_primitive().unwrap().as_slice::<i32>(), &[5]);
    }

    #[test]
    fn ree_take_out_of_bounds() {
        // Index 12 is one past the last logical element. Assert on the returned `Err` directly:
        // `#[should_panic]` + `unwrap()` would also pass on any unrelated panic.
        let result = take(
            ree_array().as_ref(),
            PrimitiveArray::from_iter([12]).as_ref(),
        );
        assert!(result.is_err());
    }

    #[test]
    fn ree_take_negative_index() {
        // A negative signed index must be rejected, not wrapped into some valid position.
        let result = take(
            ree_array().as_ref(),
            PrimitiveArray::from_iter([-1i32]).as_ref(),
        );
        assert!(result.is_err());
    }

    #[test]
    fn sliced_take() {
        let sliced = ree_array().slice(4..9);
        let taken = take(
            sliced.as_ref(),
            PrimitiveArray::from_iter([1, 3, 4]).as_ref(),
        )
        .unwrap();

        assert_eq!(taken.len(), 3);
        assert_eq!(taken.scalar_at(0), 4.into());
        assert_eq!(taken.scalar_at(1), 2.into());
        assert_eq!(taken.scalar_at(2), 5.into());
    }

    #[test]
    fn ree_take_nullable() {
        let taken = take(
            ree_array().as_ref(),
            PrimitiveArray::from_option_iter([Some(1), None]).as_ref(),
        )
        .unwrap();

        assert_eq!(
            taken.scalar_at(0),
            Scalar::new(
                DType::Primitive(PType::I32, Nullability::Nullable),
                ScalarValue::from(1i32)
            )
        );
        assert_eq!(
            taken.scalar_at(1),
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
        );
    }

    #[rstest]
    #[case(ree_array())]
    #[case(RunEndArray::encode(
        PrimitiveArray::from_iter([1u8, 1, 2, 2, 2, 3, 3, 3, 3, 4]).into_array(),
    ).unwrap())]
    #[case(RunEndArray::encode(
        PrimitiveArray::from_option_iter([
            Some(10),
            Some(10),
            None,
            None,
            Some(20),
            Some(20),
            Some(20),
        ])
        .into_array(),
    ).unwrap())]
    #[case(RunEndArray::encode(PrimitiveArray::from_iter([42i32, 42, 42, 42, 42]).into_array())
        .unwrap())]
    #[case(RunEndArray::encode(
        PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10]).into_array(),
    ).unwrap())]
    #[case({
        let mut values = Vec::new();
        for i in 0..20 {
            for _ in 0..=i {
                values.push(i);
            }
        }
        RunEndArray::encode(PrimitiveArray::from_iter(values).into_array()).unwrap()
    })]
    fn test_take_runend_conformance(#[case] array: RunEndArray) {
        test_take_conformance(array.as_ref());
    }

    #[rstest]
    #[case(ree_array().slice(3..6))]
    #[case({
        let array = RunEndArray::encode(
            PrimitiveArray::from_iter([1i32, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]).into_array(),
        )
        .unwrap();
        array.slice(2..8)
    })]
    fn test_take_sliced_runend_conformance(#[case] sliced: ArrayRef) {
        test_take_conformance(sliced.as_ref());
    }
}