vortex_sequence/compute/take.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::cast::NumCast;
5use vortex_array::arrays::{ConstantArray, PrimitiveArray};
6use vortex_array::compute::{TakeKernel, TakeKernelAdapter};
7use vortex_array::validity::Validity;
8use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
9use vortex_buffer::Buffer;
10use vortex_dtype::{
11    DType, IntegerPType, NativePType, Nullability, match_each_integer_ptype,
12    match_each_native_ptype,
13};
14use vortex_error::{VortexExpect, VortexResult};
15use vortex_mask::{AllOr, Mask};
16use vortex_scalar::Scalar;
17
18use crate::{SequenceArray, SequenceVTable};
19
20impl TakeKernel for SequenceVTable {
21    fn take(&self, array: &SequenceArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
22        let mask = indices.validity_mask();
23        let indices = indices.to_primitive();
24        let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
25
26        Ok(match_each_integer_ptype!(indices.ptype(), |T| {
27            let indices = indices.as_slice::<T>();
28            match_each_native_ptype!(array.ptype(), |S| {
29                let mul = array.multiplier().as_primitive::<S>();
30                let base = array.base().as_primitive::<S>();
31                take(mul, base, indices, mask, result_nullability)
32            })
33        }))
34    }
35}
36
37fn take<T: IntegerPType, S: NativePType>(
38    mul: S,
39    base: S,
40    indices: &[T],
41    indices_mask: Mask,
42    result_nullability: Nullability,
43) -> ArrayRef {
44    match indices_mask.boolean_buffer() {
45        AllOr::All => PrimitiveArray::new(
46            Buffer::from_trusted_len_iter(indices.iter().map(|i| {
47                let i = <S as NumCast>::from::<T>(*i).vortex_expect("all indices fit");
48                base + i * mul
49            })),
50            Validity::from(result_nullability),
51        )
52        .into_array(),
53        AllOr::None => ConstantArray::new(
54            Scalar::null(DType::Primitive(S::PTYPE, Nullability::Nullable)),
55            indices.len(),
56        )
57        .into_array(),
58        AllOr::Some(b) => {
59            let buffer =
60                Buffer::from_trusted_len_iter(indices.iter().enumerate().map(|(mask_index, i)| {
61                    if b.value(mask_index) {
62                        let i =
63                            <S as NumCast>::from::<T>(*i).vortex_expect("all valid indices fit");
64                        base + i * mul
65                    } else {
66                        S::zero()
67                    }
68                }));
69            PrimitiveArray::new(buffer, Validity::from(b.clone())).into_array()
70        }
71    }
72}
73
// Register `SequenceVTable`'s `TakeKernel` implementation via `register_kernel!`, wrapped in
// `TakeKernelAdapter`.
// NOTE(review): `.lift()` presumably adapts the typed adapter into the form the kernel
// registry stores — confirm against `vortex_array::compute`.
register_kernel!(TakeKernelAdapter(SequenceVTable).lift());
75
#[cfg(test)]
mod test {
    use rstest::rstest;
    // Aliased so it does not collide with the rstest function of the same name below.
    use vortex_array::compute::conformance::take::test_take_conformance as run_take_conformance;
    use vortex_dtype::Nullability;

    use crate::SequenceArray;

    /// Runs the shared take conformance suite over a spread of sequence shapes.
    ///
    /// The cases cover several integer widths, signed and unsigned types, positive, negative
    /// and zero multipliers, and both nullabilities.
    #[rstest]
    #[case::basic_sequence(
        SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 10).unwrap()
    )]
    #[case::sequence_with_multiplier(
        SequenceArray::typed_new(10i32, 5i32, Nullability::Nullable, 20).unwrap()
    )]
    #[case::sequence_i64(
        SequenceArray::typed_new(100i64, 10i64, Nullability::NonNullable, 50).unwrap()
    )]
    #[case::sequence_u32(
        SequenceArray::typed_new(0u32, 2u32, Nullability::NonNullable, 100).unwrap()
    )]
    #[case::sequence_negative_step(
        SequenceArray::typed_new(1000i32, -10i32, Nullability::Nullable, 30).unwrap()
    )]
    // A zero multiplier makes every element equal to the base.
    #[case::sequence_constant(
        SequenceArray::typed_new(42i32, 0i32, Nullability::Nullable, 15).unwrap()
    )]
    #[case::sequence_i16(
        SequenceArray::typed_new(-100i16, 3i16, Nullability::NonNullable, 25).unwrap()
    )]
    #[case::sequence_large(
        SequenceArray::typed_new(0i64, 1i64, Nullability::Nullable, 1000).unwrap()
    )]
    fn test_take_conformance(#[case] sequence: SequenceArray) {
        run_take_conformance(sequence.as_ref());
    }
}