Skip to main content

vortex_sequence/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::cast::NumCast;
5use vortex_array::ArrayRef;
6use vortex_array::ArrayView;
7use vortex_array::ExecutionCtx;
8use vortex_array::IntoArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::arrays::PrimitiveArray;
11use vortex_array::arrays::dict::TakeExecute;
12use vortex_array::dtype::DType;
13use vortex_array::dtype::IntegerPType;
14use vortex_array::dtype::NativePType;
15use vortex_array::dtype::Nullability;
16use vortex_array::match_each_integer_ptype;
17use vortex_array::match_each_native_ptype;
18use vortex_array::scalar::Scalar;
19use vortex_array::validity::Validity;
20use vortex_buffer::Buffer;
21use vortex_error::VortexExpect;
22use vortex_error::VortexResult;
23use vortex_error::vortex_panic;
24use vortex_mask::AllOr;
25use vortex_mask::Mask;
26
27use crate::Sequence;
28
29fn take_inner<T: IntegerPType, S: NativePType>(
30    mul: S,
31    base: S,
32    indices: &[T],
33    indices_mask: Mask,
34    result_nullability: Nullability,
35    len: usize,
36) -> ArrayRef {
37    match indices_mask.bit_buffer() {
38        AllOr::All => PrimitiveArray::new(
39            Buffer::from_trusted_len_iter(indices.iter().map(|i| {
40                if i.as_() >= len {
41                    vortex_panic!(OutOfBounds: i.as_(), 0, len);
42                }
43                let i = <S as NumCast>::from::<T>(*i).vortex_expect("all indices fit");
44                base + i * mul
45            })),
46            Validity::from(result_nullability),
47        )
48        .into_array(),
49        AllOr::None => ConstantArray::new(
50            Scalar::null(DType::Primitive(S::PTYPE, Nullability::Nullable)),
51            indices.len(),
52        )
53        .into_array(),
54        AllOr::Some(b) => {
55            let buffer =
56                Buffer::from_trusted_len_iter(indices.iter().enumerate().map(|(mask_index, i)| {
57                    if b.value(mask_index) {
58                        if i.as_() >= len {
59                            vortex_panic!(OutOfBounds: i.as_(), 0, len);
60                        }
61
62                        let i =
63                            <S as NumCast>::from::<T>(*i).vortex_expect("all valid indices fit");
64                        base + i * mul
65                    } else {
66                        S::zero()
67                    }
68                }));
69            PrimitiveArray::new(buffer, Validity::from(b.clone())).into_array()
70        }
71    }
72}
73
74impl TakeExecute for Sequence {
75    fn take(
76        array: ArrayView<'_, Self>,
77        indices: &ArrayRef,
78        ctx: &mut ExecutionCtx,
79    ) -> VortexResult<Option<ArrayRef>> {
80        let mask = indices.validity()?.execute_mask(indices.len(), ctx)?;
81        let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
82        let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
83
84        match_each_integer_ptype!(indices.ptype(), |T| {
85            let indices = indices.as_slice::<T>();
86            match_each_native_ptype!(array.ptype(), |S| {
87                let mul = array.multiplier().cast::<S>()?;
88                let base = array.base().cast::<S>()?;
89                Ok(Some(take_inner(
90                    mul,
91                    base,
92                    indices,
93                    mask,
94                    result_nullability,
95                    array.len(),
96                )))
97            })
98        })
99    }
100}
101
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::Canonical;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::dtype::Nullability;

    use crate::Sequence;
    use crate::SequenceArray;

    /// Runs the shared take conformance suite against a variety of sequence arrays.
    #[rstest]
    #[case::basic_sequence(
        Sequence::try_new_typed(0i32, 1i32, Nullability::NonNullable, 10).unwrap()
    )]
    #[case::sequence_with_multiplier(
        Sequence::try_new_typed(10i32, 5i32, Nullability::Nullable, 20).unwrap()
    )]
    #[case::sequence_i64(
        Sequence::try_new_typed(100i64, 10i64, Nullability::NonNullable, 50).unwrap()
    )]
    #[case::sequence_u32(
        Sequence::try_new_typed(0u32, 2u32, Nullability::NonNullable, 100).unwrap()
    )]
    #[case::sequence_negative_step(
        Sequence::try_new_typed(1000i32, -10i32, Nullability::Nullable, 30).unwrap()
    )]
    // A multiplier of 0 means all values are the same.
    #[case::sequence_constant(
        Sequence::try_new_typed(42i32, 0i32, Nullability::Nullable, 15).unwrap()
    )]
    #[case::sequence_i16(
        Sequence::try_new_typed(-100i16, 3i16, Nullability::NonNullable, 25).unwrap()
    )]
    #[case::sequence_large(
        Sequence::try_new_typed(0i64, 1i64, Nullability::Nullable, 1000).unwrap()
    )]
    fn test_take_conformance(#[case] sequence: SequenceArray) {
        // Aliased so the imported suite is not confused with this test function's own name.
        use vortex_array::compute::conformance::take::test_take_conformance as run_conformance;

        let array = sequence.into_array();
        run_conformance(&array);
    }

    /// Taking index 20 from a sequence of length 10 must panic with an out-of-bounds error.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_bounds_check() {
        let sequence = Sequence::try_new_typed(0i32, 1i32, Nullability::NonNullable, 10).unwrap();
        let out_of_range = PrimitiveArray::from_iter([0i32, 20]).into_array();

        let taken = sequence.take(out_of_range).unwrap();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let _canonical = taken.execute::<Canonical>(&mut ctx).unwrap();
    }
}