vortex_sequence/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::cast::NumCast;
5use vortex_array::Array;
6use vortex_array::ArrayRef;
7use vortex_array::IntoArray;
8use vortex_array::ToCanonical;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::arrays::PrimitiveArray;
11use vortex_array::compute::TakeKernel;
12use vortex_array::compute::TakeKernelAdapter;
13use vortex_array::register_kernel;
14use vortex_array::validity::Validity;
15use vortex_buffer::Buffer;
16use vortex_dtype::DType;
17use vortex_dtype::IntegerPType;
18use vortex_dtype::NativePType;
19use vortex_dtype::Nullability;
20use vortex_dtype::match_each_integer_ptype;
21use vortex_dtype::match_each_native_ptype;
22use vortex_error::VortexExpect;
23use vortex_error::VortexResult;
24use vortex_error::vortex_panic;
25use vortex_mask::AllOr;
26use vortex_mask::Mask;
27use vortex_scalar::Scalar;
28
29use crate::SequenceArray;
30use crate::SequenceVTable;
31
32impl TakeKernel for SequenceVTable {
33    fn take(&self, array: &SequenceArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
34        let mask = indices.validity_mask();
35        let indices = indices.to_primitive();
36        let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
37
38        match_each_integer_ptype!(indices.ptype(), |T| {
39            let indices = indices.as_slice::<T>();
40            match_each_native_ptype!(array.ptype(), |S| {
41                let mul = array.multiplier().cast::<S>();
42                let base = array.base().cast::<S>();
43                Ok(take(
44                    mul,
45                    base,
46                    indices,
47                    mask,
48                    result_nullability,
49                    array.len(),
50                ))
51            })
52        })
53    }
54}
55
56fn take<T: IntegerPType, S: NativePType>(
57    mul: S,
58    base: S,
59    indices: &[T],
60    indices_mask: Mask,
61    result_nullability: Nullability,
62    len: usize,
63) -> ArrayRef {
64    match indices_mask.bit_buffer() {
65        AllOr::All => PrimitiveArray::new(
66            Buffer::from_trusted_len_iter(indices.iter().map(|i| {
67                if i.as_() >= len {
68                    vortex_panic!(OutOfBounds: i.as_(), 0, len);
69                }
70                let i = <S as NumCast>::from::<T>(*i).vortex_expect("all indices fit");
71                base + i * mul
72            })),
73            Validity::from(result_nullability),
74        )
75        .into_array(),
76        AllOr::None => ConstantArray::new(
77            Scalar::null(DType::Primitive(S::PTYPE, Nullability::Nullable)),
78            indices.len(),
79        )
80        .into_array(),
81        AllOr::Some(b) => {
82            let buffer =
83                Buffer::from_trusted_len_iter(indices.iter().enumerate().map(|(mask_index, i)| {
84                    if b.value(mask_index) {
85                        if i.as_() >= len {
86                            vortex_panic!(OutOfBounds: i.as_(), 0, len);
87                        }
88
89                        let i =
90                            <S as NumCast>::from::<T>(*i).vortex_expect("all valid indices fit");
91                        base + i * mul
92                    } else {
93                        S::zero()
94                    }
95                }));
96            PrimitiveArray::new(buffer, Validity::from(b.clone())).into_array()
97        }
98    }
99}
100
// Registers the take kernel for `SequenceVTable` through `register_kernel!`, wrapped in
// `TakeKernelAdapter`. NOTE(review): `.lift()` presumably converts the vtable-typed adapter into
// the form the kernel registry stores. Confirm against vortex_array.
register_kernel!(TakeKernelAdapter(SequenceVTable).lift());
102
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::take;
    use vortex_dtype::Nullability;

    use crate::SequenceArray;

    /// Runs the shared `take` conformance suite against several sequence shapes.
    ///
    /// The cases cover different integer widths, signed and unsigned values, negative and zero
    /// multipliers, and both nullabilities.
    #[rstest]
    #[case::basic_sequence(
        SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 10).unwrap()
    )]
    #[case::sequence_with_multiplier(
        SequenceArray::typed_new(10i32, 5i32, Nullability::Nullable, 20).unwrap()
    )]
    #[case::sequence_i64(
        SequenceArray::typed_new(100i64, 10i64, Nullability::NonNullable, 50).unwrap()
    )]
    #[case::sequence_u32(
        SequenceArray::typed_new(0u32, 2u32, Nullability::NonNullable, 100).unwrap()
    )]
    #[case::sequence_negative_step(
        SequenceArray::typed_new(1000i32, -10i32, Nullability::Nullable, 30).unwrap()
    )]
    // A multiplier of 0 makes every element equal to the base.
    #[case::sequence_constant(
        SequenceArray::typed_new(42i32, 0i32, Nullability::Nullable, 15).unwrap()
    )]
    #[case::sequence_i16(
        SequenceArray::typed_new(-100i16, 3i16, Nullability::NonNullable, 25).unwrap()
    )]
    #[case::sequence_large(
        SequenceArray::typed_new(0i64, 1i64, Nullability::Nullable, 1000).unwrap()
    )]
    fn test_take_conformance(#[case] array: SequenceArray) {
        vortex_array::compute::conformance::take::test_take_conformance(array.as_ref());
    }

    /// Taking an index past the end of the sequence must panic with an out-of-bounds error.
    #[test]
    #[should_panic(expected = "index 20 out of bounds")]
    fn test_bounds_check() {
        let sequence = SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 10).unwrap();
        let indices = PrimitiveArray::from_iter([0i32, 20]);
        let _taken = take(sequence.as_ref(), indices.as_ref()).unwrap();
    }
}