vortex_array/arrays/decimal/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use vortex_buffer::Buffer;
6use vortex_dtype::{NativePType, match_each_integer_ptype};
7use vortex_error::VortexResult;
8use vortex_scalar::{NativeDecimalType, match_each_decimal_value_type};
9
10use crate::arrays::{DecimalArray, DecimalVTable};
11use crate::compute::{TakeKernel, TakeKernelAdapter};
12use crate::vtable::ValidityHelper;
13use crate::{Array, ArrayRef, ToCanonical, register_kernel};
14
15impl TakeKernel for DecimalVTable {
16    fn take(&self, array: &DecimalArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
17        let indices = indices.to_primitive()?;
18        let validity = array.validity().take(indices.as_ref())?;
19
20        // TODO(joe): if the true count of take indices validity is low, only take array values with
21        // valid indices.
22        let decimal = match_each_decimal_value_type!(array.values_type(), |D| {
23            match_each_integer_ptype!(indices.ptype(), |I| {
24                let buffer =
25                    take_to_buffer::<I, D>(indices.as_slice::<I>(), array.buffer::<D>().as_slice());
26                DecimalArray::new(buffer, array.decimal_dtype(), validity)
27            })
28        });
29
30        Ok(decimal.to_array())
31    }
32}
33
34register_kernel!(TakeKernelAdapter(DecimalVTable).lift());
35
36#[inline]
37fn take_to_buffer<I: NativePType + AsPrimitive<usize>, T: NativeDecimalType>(
38    indices: &[I],
39    values: &[T],
40) -> Buffer<T> {
41    indices.iter().map(|idx| values[idx.as_()]).collect()
42}
43
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::IntoArray;
    use crate::arrays::{DecimalArray, DecimalVTable, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    // Taking a subset of a non-nullable i128 decimal array keeps the values in index
    // order and preserves the decimal dtype.
    #[test]
    fn test_take() {
        let source = DecimalArray::new(
            buffer![10i128, 11i128, 12i128, 13i128],
            DecimalDType::new(19, 1),
            Validity::NonNullable,
        );
        let positions = buffer![0, 2, 3].into_array();

        let result = take(source.as_ref(), positions.as_ref()).unwrap();
        let result = result.as_::<DecimalVTable>();

        assert_eq!(result.buffer::<i128>(), buffer![10i128, 12i128, 13i128]);
        assert_eq!(result.decimal_dtype(), DecimalDType::new(19, 1));
    }

    // A null index produces a null output slot, and valid indices pick the matching values
    // as nullable decimal scalars.
    #[test]
    fn test_take_null_indices() {
        let source = DecimalArray::new(
            buffer![i128::MAX, 11i128, 12i128, 13i128],
            DecimalDType::new(19, 1),
            Validity::NonNullable,
        );
        let positions = PrimitiveArray::from_option_iter([None, Some(2), Some(3)]).into_array();

        let result = take(source.as_ref(), positions.as_ref()).unwrap();

        assert!(result.scalar_at(0).is_null());

        for (slot, expected) in [(1usize, 12i128), (2, 13)] {
            assert_eq!(
                result.scalar_at(slot),
                Scalar::decimal(
                    DecimalValue::I128(expected),
                    source.decimal_dtype(),
                    Nullability::Nullable
                )
            );
        }
    }

    // Shared take conformance checks across value widths, nullability and lengths.
    #[rstest]
    #[case(DecimalArray::new(
        buffer![100i128, 200i128, 300i128, 400i128, 500i128],
        DecimalDType::new(19, 2),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![10i64, 20i64, 30i64, 40i64, 50i64],
        DecimalDType::new(10, 1),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![1i32, 2i32, 3i32, 4i32, 5i32],
        DecimalDType::new(5, 0),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![1000i128, 2000i128, 3000i128, 4000i128, 5000i128],
        DecimalDType::new(19, 3),
        Validity::from_iter([true, false, true, true, false]),
    ))]
    #[case(DecimalArray::new(
        buffer![42i128],
        DecimalDType::new(19, 0),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        Buffer::from_iter((0..100i128).map(|step| step * 1000)),
        DecimalDType::new(19, 4),
        Validity::NonNullable,
    ))]
    fn test_take_decimal_conformance(#[case] decimals: DecimalArray) {
        test_take_conformance(decimals.as_ref());
    }
}