Skip to main content

vortex_sequence/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::cast::NumCast;
5use vortex_array::ArrayRef;
6use vortex_array::DynArray;
7use vortex_array::ExecutionCtx;
8use vortex_array::IntoArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::arrays::PrimitiveArray;
11use vortex_array::arrays::dict::TakeExecute;
12use vortex_array::dtype::DType;
13use vortex_array::dtype::IntegerPType;
14use vortex_array::dtype::NativePType;
15use vortex_array::dtype::Nullability;
16use vortex_array::match_each_integer_ptype;
17use vortex_array::match_each_native_ptype;
18use vortex_array::scalar::Scalar;
19use vortex_array::validity::Validity;
20use vortex_buffer::Buffer;
21use vortex_error::VortexExpect;
22use vortex_error::VortexResult;
23use vortex_error::vortex_panic;
24use vortex_mask::AllOr;
25use vortex_mask::Mask;
26
27use crate::Sequence;
28use crate::SequenceArray;
29
30fn take_inner<T: IntegerPType, S: NativePType>(
31    mul: S,
32    base: S,
33    indices: &[T],
34    indices_mask: Mask,
35    result_nullability: Nullability,
36    len: usize,
37) -> ArrayRef {
38    match indices_mask.bit_buffer() {
39        AllOr::All => PrimitiveArray::new(
40            Buffer::from_trusted_len_iter(indices.iter().map(|i| {
41                if i.as_() >= len {
42                    vortex_panic!(OutOfBounds: i.as_(), 0, len);
43                }
44                let i = <S as NumCast>::from::<T>(*i).vortex_expect("all indices fit");
45                base + i * mul
46            })),
47            Validity::from(result_nullability),
48        )
49        .into_array(),
50        AllOr::None => ConstantArray::new(
51            Scalar::null(DType::Primitive(S::PTYPE, Nullability::Nullable)),
52            indices.len(),
53        )
54        .into_array(),
55        AllOr::Some(b) => {
56            let buffer =
57                Buffer::from_trusted_len_iter(indices.iter().enumerate().map(|(mask_index, i)| {
58                    if b.value(mask_index) {
59                        if i.as_() >= len {
60                            vortex_panic!(OutOfBounds: i.as_(), 0, len);
61                        }
62
63                        let i =
64                            <S as NumCast>::from::<T>(*i).vortex_expect("all valid indices fit");
65                        base + i * mul
66                    } else {
67                        S::zero()
68                    }
69                }));
70            PrimitiveArray::new(buffer, Validity::from(b.clone())).into_array()
71        }
72    }
73}
74
75impl TakeExecute for Sequence {
76    fn take(
77        array: &SequenceArray,
78        indices: &ArrayRef,
79        ctx: &mut ExecutionCtx,
80    ) -> VortexResult<Option<ArrayRef>> {
81        let mask = indices.validity_mask()?;
82        let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
83        let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
84
85        match_each_integer_ptype!(indices.ptype(), |T| {
86            let indices = indices.as_slice::<T>();
87            match_each_native_ptype!(array.ptype(), |S| {
88                let mul = array.multiplier().cast::<S>()?;
89                let base = array.base().cast::<S>()?;
90                Ok(Some(take_inner(
91                    mul,
92                    base,
93                    indices,
94                    mask,
95                    result_nullability,
96                    array.len(),
97                )))
98            })
99        })
100    }
101}
102
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::Canonical;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::dtype::Nullability;

    use crate::SequenceArray;

    /// Runs the shared take conformance suite over sequences that vary in integer width,
    /// signedness, nullability, length, and multiplier sign (positive, negative, and zero).
    #[rstest]
    #[case::basic_sequence(
        SequenceArray::try_new_typed(0i32, 1i32, Nullability::NonNullable, 10).unwrap()
    )]
    #[case::sequence_with_multiplier(
        SequenceArray::try_new_typed(10i32, 5i32, Nullability::Nullable, 20).unwrap()
    )]
    #[case::sequence_i64(
        SequenceArray::try_new_typed(100i64, 10i64, Nullability::NonNullable, 50).unwrap()
    )]
    #[case::sequence_u32(
        SequenceArray::try_new_typed(0u32, 2u32, Nullability::NonNullable, 100).unwrap()
    )]
    #[case::sequence_negative_step(
        SequenceArray::try_new_typed(1000i32, -10i32, Nullability::Nullable, 30).unwrap()
    )]
    #[case::sequence_constant(
        // multiplier of 0 means all values are the same
        SequenceArray::try_new_typed(42i32, 0i32, Nullability::Nullable, 15).unwrap()
    )]
    #[case::sequence_i16(
        SequenceArray::try_new_typed(-100i16, 3i16, Nullability::NonNullable, 25).unwrap()
    )]
    #[case::sequence_large(
        SequenceArray::try_new_typed(0i64, 1i64, Nullability::Nullable, 1000).unwrap()
    )]
    fn test_take_conformance(#[case] sequence: SequenceArray) {
        // Fully-qualified so the helper does not clash with this test's own name.
        vortex_array::compute::conformance::take::test_take_conformance(&sequence.into_array());
    }

    /// Taking an index past the end of the sequence must panic with an out-of-bounds error
    /// once the lazy take is executed.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_bounds_check() {
        let sequence =
            SequenceArray::try_new_typed(0i32, 1i32, Nullability::NonNullable, 10).unwrap();
        // Index 20 lies beyond the sequence length of 10.
        let indices = PrimitiveArray::from_iter([0i32, 20]).into_array();
        let taken = sequence.take(indices).unwrap();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let _canonical = taken.execute::<Canonical>(&mut ctx).unwrap();
    }
}