vortex_array/arrays/decimal/compute/take.rs

1use num_traits::AsPrimitive;
2use vortex_buffer::Buffer;
3use vortex_dtype::{NativePType, match_each_integer_ptype};
4use vortex_error::VortexResult;
5use vortex_scalar::{NativeDecimalType, match_each_decimal_value_type};
6
7use crate::arrays::{DecimalArray, DecimalVTable};
8use crate::compute::{TakeKernel, TakeKernelAdapter};
9use crate::vtable::ValidityHelper;
10use crate::{Array, ArrayRef, ToCanonical, register_kernel};
11
12impl TakeKernel for DecimalVTable {
13    fn take(&self, array: &DecimalArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
14        let indices = indices.to_primitive()?;
15        let validity = array.validity().take(indices.as_ref())?;
16
17        // TODO(joe): if the true count of take indices validity is low, only take array values with
18        // valid indices.
19        let decimal = match_each_decimal_value_type!(array.values_type(), |D| {
20            match_each_integer_ptype!(indices.ptype(), |I| {
21                let buffer =
22                    take_to_buffer::<I, D>(indices.as_slice::<I>(), array.buffer::<D>().as_slice());
23                DecimalArray::new(buffer, array.decimal_dtype(), validity)
24            })
25        });
26
27        Ok(decimal.to_array())
28    }
29}
30
31register_kernel!(TakeKernelAdapter(DecimalVTable).lift());
32
33#[inline]
34fn take_to_buffer<I: NativePType + AsPrimitive<usize>, T: NativeDecimalType>(
35    indices: &[I],
36    values: &[T],
37) -> Buffer<T> {
38    indices.iter().map(|idx| values[idx.as_()]).collect()
39}
40
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DecimalDType;

    use crate::IntoArray;
    use crate::arrays::{DecimalArray, DecimalVTable};
    use crate::compute::take;
    use crate::validity::Validity;

    /// Builds the non-nullable `i128`-backed fixture shared by the tests below.
    fn fixture() -> DecimalArray {
        DecimalArray::new(
            buffer![10i128, 11i128, 12i128, 13i128],
            DecimalDType::new(19, 1),
            Validity::NonNullable,
        )
    }

    /// Ascending, distinct `i32` indices select the expected values and keep the dtype.
    #[test]
    fn test_take() {
        let array = fixture();

        let indices = buffer![0, 2, 3].into_array();
        let taken = take(array.as_ref(), indices.as_ref()).unwrap();
        let taken_decimals = taken.as_::<DecimalVTable>();
        assert_eq!(
            taken_decimals.buffer::<i128>(),
            buffer![10i128, 12i128, 13i128]
        );
        assert_eq!(taken_decimals.decimal_dtype(), DecimalDType::new(19, 1));
    }

    /// Unsigned (`u64`) indices go through a different index-type dispatch arm. Unordered
    /// and repeated indices must be gathered in index order, one output per index.
    #[test]
    fn test_take_unsigned_unordered_repeated_indices() {
        let array = fixture();

        let indices = buffer![3u64, 0, 3, 1].into_array();
        let taken = take(array.as_ref(), indices.as_ref()).unwrap();
        let taken_decimals = taken.as_::<DecimalVTable>();
        assert_eq!(
            taken_decimals.buffer::<i128>(),
            buffer![13i128, 10i128, 13i128, 11i128]
        );
        assert_eq!(taken_decimals.decimal_dtype(), DecimalDType::new(19, 1));
    }
}