vortex_array/arrays/decimal/compute/take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_dtype::{
6    IntegerPType, NativeDecimalType, match_each_decimal_value_type, match_each_integer_ptype,
7};
8use vortex_error::VortexResult;
9
10use crate::arrays::{DecimalArray, DecimalVTable};
11use crate::compute::{TakeKernel, TakeKernelAdapter};
12use crate::vtable::ValidityHelper;
13use crate::{Array, ArrayRef, ToCanonical, register_kernel};
14
15impl TakeKernel for DecimalVTable {
16    fn take(&self, array: &DecimalArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
17        let indices = indices.to_primitive();
18        let validity = array.validity().take(indices.as_ref())?;
19
20        // TODO(joe): if the true count of take indices validity is low, only take array values with
21        // valid indices.
22        let decimal = match_each_decimal_value_type!(array.values_type(), |D| {
23            match_each_integer_ptype!(indices.ptype(), |I| {
24                let buffer =
25                    take_to_buffer::<I, D>(indices.as_slice::<I>(), array.buffer::<D>().as_slice());
26                // SAFETY: Take operation preserves decimal dtype and creates valid buffer.
27                // Validity is computed correctly from the parent array and indices.
28                unsafe { DecimalArray::new_unchecked(buffer, array.decimal_dtype(), validity) }
29            })
30        });
31
32        Ok(decimal.to_array())
33    }
34}
35
36register_kernel!(TakeKernelAdapter(DecimalVTable).lift());
37
38#[inline]
39fn take_to_buffer<I: IntegerPType, T: NativeDecimalType>(indices: &[I], values: &[T]) -> Buffer<T> {
40    indices.iter().map(|idx| values[idx.as_()]).collect()
41}
42
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::IntoArray;
    use crate::arrays::{DecimalArray, DecimalVTable, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Taking from a non-nullable `i128`-backed array returns the selected values and keeps
    /// the source decimal dtype.
    #[test]
    fn test_take() {
        let array = DecimalArray::new(
            buffer![10i128, 11i128, 12i128, 13i128],
            DecimalDType::new(19, 1),
            Validity::NonNullable,
        );

        let indices = buffer![0, 2, 3].into_array();
        let taken = take(array.as_ref(), indices.as_ref()).unwrap();
        let taken_decimals = taken.as_::<DecimalVTable>();
        assert_eq!(
            taken_decimals.buffer::<i128>(),
            buffer![10i128, 12i128, 13i128]
        );
        assert_eq!(taken_decimals.decimal_dtype(), DecimalDType::new(19, 1));
    }

    /// A null index yields a null output slot even though every source value is valid.
    /// Non-null indices still select their values, and the result is nullable.
    #[test]
    fn test_take_null_indices() {
        let array = DecimalArray::new(
            buffer![i128::MAX, 11i128, 12i128, 13i128],
            DecimalDType::new(19, 1),
            Validity::NonNullable,
        );

        let indices = PrimitiveArray::from_option_iter([None, Some(2), Some(3)]).into_array();
        let taken = take(array.as_ref(), indices.as_ref()).unwrap();

        assert!(taken.scalar_at(0).is_null());
        assert_eq!(
            taken.scalar_at(1),
            Scalar::decimal(
                DecimalValue::I128(12i128),
                array.decimal_dtype(),
                Nullability::Nullable
            )
        );

        assert_eq!(
            taken.scalar_at(2),
            Scalar::decimal(
                DecimalValue::I128(13i128),
                array.decimal_dtype(),
                Nullability::Nullable
            )
        );
    }

    /// Covers a nullable source combined with nullable indices. The output has:
    /// - a null where a valid index points at a null value;
    /// - a null where the index itself is null;
    /// - the value where a valid index points at a valid value.
    #[test]
    fn test_take_nullable_values_with_null_indices() {
        let array = DecimalArray::new(
            buffer![10i128, 11i128, 12i128, 13i128],
            DecimalDType::new(19, 1),
            Validity::from_iter([true, false, true, true]),
        );

        let indices =
            PrimitiveArray::from_option_iter([Some(1u32), None, Some(2u32)]).into_array();
        let taken = take(array.as_ref(), indices.as_ref()).unwrap();

        // Index 1 selects a null source value.
        assert!(taken.scalar_at(0).is_null());
        // The index slot itself is null.
        assert!(taken.scalar_at(1).is_null());
        assert_eq!(
            taken.scalar_at(2),
            Scalar::decimal(
                DecimalValue::I128(12i128),
                array.decimal_dtype(),
                Nullability::Nullable
            )
        );
    }

    /// Narrow (`u8`) indices go through the integer-ptype dispatch. They may repeat and
    /// reorder positions.
    #[test]
    fn test_take_u8_indices() {
        let array = DecimalArray::new(
            buffer![1i32, 2i32, 3i32],
            DecimalDType::new(5, 0),
            Validity::NonNullable,
        );

        let indices = buffer![2u8, 0u8, 2u8].into_array();
        let taken = take(array.as_ref(), indices.as_ref()).unwrap();
        let taken_decimals = taken.as_::<DecimalVTable>();
        assert_eq!(taken_decimals.buffer::<i32>(), buffer![3i32, 1i32, 3i32]);
        assert_eq!(taken_decimals.decimal_dtype(), DecimalDType::new(5, 0));
    }

    /// Runs the shared take conformance suite over several decimal storage widths, a nullable
    /// array, a single-element array, and a larger array.
    #[rstest]
    #[case(DecimalArray::new(
        buffer![100i128, 200i128, 300i128, 400i128, 500i128],
        DecimalDType::new(19, 2),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![10i64, 20i64, 30i64, 40i64, 50i64],
        DecimalDType::new(10, 1),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![1i32, 2i32, 3i32, 4i32, 5i32],
        DecimalDType::new(5, 0),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![1000i128, 2000i128, 3000i128, 4000i128, 5000i128],
        DecimalDType::new(19, 3),
        Validity::from_iter([true, false, true, true, false]),
    ))]
    #[case(DecimalArray::new(
        buffer![42i128],
        DecimalDType::new(19, 0),
        Validity::NonNullable,
    ))]
    #[case({
        let values: Vec<i128> = (0..100).map(|i| i * 1000).collect();
        DecimalArray::new(
            Buffer::from_iter(values),
            DecimalDType::new(19, 4),
            Validity::NonNullable,
        )
    })]
    fn test_take_decimal_conformance(#[case] array: DecimalArray) {
        test_take_conformance(array.as_ref());
    }
}