Skip to main content

vortex_sequence/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::cast::NumCast;
5use vortex_array::Array;
6use vortex_array::ArrayRef;
7use vortex_array::ExecutionCtx;
8use vortex_array::IntoArray;
9use vortex_array::ToCanonical;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::arrays::PrimitiveArray;
12use vortex_array::arrays::TakeExecute;
13use vortex_array::dtype::DType;
14use vortex_array::dtype::IntegerPType;
15use vortex_array::dtype::NativePType;
16use vortex_array::dtype::Nullability;
17use vortex_array::match_each_integer_ptype;
18use vortex_array::match_each_native_ptype;
19use vortex_array::scalar::Scalar;
20use vortex_array::validity::Validity;
21use vortex_buffer::Buffer;
22use vortex_error::VortexExpect;
23use vortex_error::VortexResult;
24use vortex_error::vortex_panic;
25use vortex_mask::AllOr;
26use vortex_mask::Mask;
27
28use crate::SequenceArray;
29use crate::SequenceVTable;
30
31fn take_inner<T: IntegerPType, S: NativePType>(
32    mul: S,
33    base: S,
34    indices: &[T],
35    indices_mask: Mask,
36    result_nullability: Nullability,
37    len: usize,
38) -> ArrayRef {
39    match indices_mask.bit_buffer() {
40        AllOr::All => PrimitiveArray::new(
41            Buffer::from_trusted_len_iter(indices.iter().map(|i| {
42                if i.as_() >= len {
43                    vortex_panic!(OutOfBounds: i.as_(), 0, len);
44                }
45                let i = <S as NumCast>::from::<T>(*i).vortex_expect("all indices fit");
46                base + i * mul
47            })),
48            Validity::from(result_nullability),
49        )
50        .into_array(),
51        AllOr::None => ConstantArray::new(
52            Scalar::null(DType::Primitive(S::PTYPE, Nullability::Nullable)),
53            indices.len(),
54        )
55        .into_array(),
56        AllOr::Some(b) => {
57            let buffer =
58                Buffer::from_trusted_len_iter(indices.iter().enumerate().map(|(mask_index, i)| {
59                    if b.value(mask_index) {
60                        if i.as_() >= len {
61                            vortex_panic!(OutOfBounds: i.as_(), 0, len);
62                        }
63
64                        let i =
65                            <S as NumCast>::from::<T>(*i).vortex_expect("all valid indices fit");
66                        base + i * mul
67                    } else {
68                        S::zero()
69                    }
70                }));
71            PrimitiveArray::new(buffer, Validity::from(b.clone())).into_array()
72        }
73    }
74}
75
76impl TakeExecute for SequenceVTable {
77    fn take(
78        array: &SequenceArray,
79        indices: &ArrayRef,
80        _ctx: &mut ExecutionCtx,
81    ) -> VortexResult<Option<ArrayRef>> {
82        let mask = indices.validity_mask()?;
83        let indices = indices.to_primitive();
84        let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
85
86        match_each_integer_ptype!(indices.ptype(), |T| {
87            let indices = indices.as_slice::<T>();
88            match_each_native_ptype!(array.ptype(), |S| {
89                let mul = array.multiplier().cast::<S>()?;
90                let base = array.base().cast::<S>()?;
91                Ok(Some(take_inner(
92                    mul,
93                    base,
94                    indices,
95                    mask,
96                    result_nullability,
97                    array.len(),
98                )))
99            })
100        })
101    }
102}
103
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::Canonical;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::dtype::Nullability;

    use crate::SequenceArray;

    /// Runs the shared take conformance suite against a variety of sequence shapes:
    /// different primitive types, nullabilities, lengths, and step directions.
    #[rstest]
    #[case::basic_sequence(
        SequenceArray::try_new_typed(0i32, 1i32, Nullability::NonNullable, 10).unwrap()
    )]
    #[case::sequence_with_multiplier(
        SequenceArray::try_new_typed(10i32, 5i32, Nullability::Nullable, 20).unwrap()
    )]
    #[case::sequence_i64(
        SequenceArray::try_new_typed(100i64, 10i64, Nullability::NonNullable, 50).unwrap()
    )]
    #[case::sequence_u32(
        SequenceArray::try_new_typed(0u32, 2u32, Nullability::NonNullable, 100).unwrap()
    )]
    #[case::sequence_negative_step(
        SequenceArray::try_new_typed(1000i32, -10i32, Nullability::Nullable, 30).unwrap()
    )]
    // A multiplier of 0 means every value in the sequence is the same.
    #[case::sequence_constant(
        SequenceArray::try_new_typed(42i32, 0i32, Nullability::Nullable, 15).unwrap()
    )]
    #[case::sequence_i16(
        SequenceArray::try_new_typed(-100i16, 3i16, Nullability::NonNullable, 25).unwrap()
    )]
    #[case::sequence_large(
        SequenceArray::try_new_typed(0i64, 1i64, Nullability::Nullable, 1000).unwrap()
    )]
    fn test_take_conformance(#[case] sequence: SequenceArray) {
        // Called by full path because the suite entry point shares this test's name.
        vortex_array::compute::conformance::take::test_take_conformance(&sequence.to_array());
    }

    /// Taking an index past the end of the sequence must panic with an out-of-bounds error.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_bounds_check() {
        let seq = SequenceArray::try_new_typed(0i32, 1i32, Nullability::NonNullable, 10).unwrap();
        // Index 20 lies beyond the 10-element sequence.
        let out_of_range = PrimitiveArray::from_iter([0i32, 20]).to_array();

        let taken = seq.take(out_of_range).unwrap();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let _canonical = taken.execute::<Canonical>(&mut ctx).unwrap();
    }
}