vortex_sequence/compute/take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::cast::NumCast;
5use vortex_array::arrays::{ConstantArray, PrimitiveArray};
6use vortex_array::compute::{TakeKernel, TakeKernelAdapter};
7use vortex_array::validity::Validity;
8use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
9use vortex_buffer::Buffer;
10use vortex_dtype::{
11    DType, IntegerPType, NativePType, Nullability, match_each_integer_ptype,
12    match_each_native_ptype,
13};
14use vortex_error::{VortexExpect, VortexResult, vortex_panic};
15use vortex_mask::{AllOr, Mask};
16use vortex_scalar::Scalar;
17
18use crate::{SequenceArray, SequenceVTable};
19
20impl TakeKernel for SequenceVTable {
21    fn take(&self, array: &SequenceArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
22        let mask = indices.validity_mask();
23        let indices = indices.to_primitive();
24        let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
25
26        match_each_integer_ptype!(indices.ptype(), |T| {
27            let indices = indices.as_slice::<T>();
28            match_each_native_ptype!(array.ptype(), |S| {
29                let mul = array.multiplier().cast::<S>();
30                let base = array.base().cast::<S>();
31                Ok(take(
32                    mul,
33                    base,
34                    indices,
35                    mask,
36                    result_nullability,
37                    array.len(),
38                ))
39            })
40        })
41    }
42}
43
44fn take<T: IntegerPType, S: NativePType>(
45    mul: S,
46    base: S,
47    indices: &[T],
48    indices_mask: Mask,
49    result_nullability: Nullability,
50    len: usize,
51) -> ArrayRef {
52    match indices_mask.bit_buffer() {
53        AllOr::All => PrimitiveArray::new(
54            Buffer::from_trusted_len_iter(indices.iter().map(|i| {
55                if i.as_() >= len {
56                    vortex_panic!(OutOfBounds: i.as_(), 0, len);
57                }
58                let i = <S as NumCast>::from::<T>(*i).vortex_expect("all indices fit");
59                base + i * mul
60            })),
61            Validity::from(result_nullability),
62        )
63        .into_array(),
64        AllOr::None => ConstantArray::new(
65            Scalar::null(DType::Primitive(S::PTYPE, Nullability::Nullable)),
66            indices.len(),
67        )
68        .into_array(),
69        AllOr::Some(b) => {
70            let buffer =
71                Buffer::from_trusted_len_iter(indices.iter().enumerate().map(|(mask_index, i)| {
72                    if b.value(mask_index) {
73                        if i.as_() >= len {
74                            vortex_panic!(OutOfBounds: i.as_(), 0, len);
75                        }
76
77                        let i =
78                            <S as NumCast>::from::<T>(*i).vortex_expect("all valid indices fit");
79                        base + i * mul
80                    } else {
81                        S::zero()
82                    }
83                }));
84            PrimitiveArray::new(buffer, Validity::from(b.clone())).into_array()
85        }
86    }
87}
88
89register_kernel!(TakeKernelAdapter(SequenceVTable).lift());
90
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::take;
    use vortex_dtype::Nullability;

    use crate::SequenceArray;

    /// Runs the shared `take` conformance suite over sequences that vary in
    /// element type, base, multiplier (including negative and zero), nullability
    /// and length.
    #[rstest]
    #[case::basic_sequence(
        SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 10).unwrap()
    )]
    #[case::sequence_with_multiplier(
        SequenceArray::typed_new(10i32, 5i32, Nullability::Nullable, 20).unwrap()
    )]
    #[case::sequence_i64(
        SequenceArray::typed_new(100i64, 10i64, Nullability::NonNullable, 50).unwrap()
    )]
    #[case::sequence_u32(
        SequenceArray::typed_new(0u32, 2u32, Nullability::NonNullable, 100).unwrap()
    )]
    #[case::sequence_negative_step(
        SequenceArray::typed_new(1000i32, -10i32, Nullability::Nullable, 30).unwrap()
    )]
    // A multiplier of 0 makes every element equal to the base.
    #[case::sequence_constant(
        SequenceArray::typed_new(42i32, 0i32, Nullability::Nullable, 15).unwrap()
    )]
    #[case::sequence_i16(
        SequenceArray::typed_new(-100i16, 3i16, Nullability::NonNullable, 25).unwrap()
    )]
    #[case::sequence_large(
        SequenceArray::typed_new(0i64, 1i64, Nullability::Nullable, 1000).unwrap()
    )]
    fn test_take_conformance(#[case] sequence: SequenceArray) {
        vortex_array::compute::conformance::take::test_take_conformance(sequence.as_ref());
    }

    /// Taking an index past the end of the sequence must panic.
    #[test]
    #[should_panic(expected = "index 20 out of bounds")]
    fn test_bounds_check() {
        let sequence =
            SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 10).unwrap();
        let out_of_range = PrimitiveArray::from_iter([0i32, 20]);
        let _taken = take(sequence.as_ref(), out_of_range.as_ref()).unwrap();
    }
}