vortex_array/arrays/decimal/compute/take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_dtype::IntegerPType;
6use vortex_dtype::NativeDecimalType;
7use vortex_dtype::match_each_decimal_value_type;
8use vortex_dtype::match_each_integer_ptype;
9use vortex_error::VortexResult;
10
11use crate::Array;
12use crate::ArrayRef;
13use crate::ToCanonical;
14use crate::arrays::DecimalArray;
15use crate::arrays::DecimalVTable;
16use crate::compute::TakeKernel;
17use crate::compute::TakeKernelAdapter;
18use crate::register_kernel;
19use crate::vtable::ValidityHelper;
20
21impl TakeKernel for DecimalVTable {
22    fn take(&self, array: &DecimalArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
23        let indices = indices.to_primitive();
24        let validity = array.validity().take(indices.as_ref())?;
25
26        // TODO(joe): if the true count of take indices validity is low, only take array values with
27        // valid indices.
28        let decimal = match_each_decimal_value_type!(array.values_type(), |D| {
29            match_each_integer_ptype!(indices.ptype(), |I| {
30                let buffer =
31                    take_to_buffer::<I, D>(indices.as_slice::<I>(), array.buffer::<D>().as_slice());
32                // SAFETY: Take operation preserves decimal dtype and creates valid buffer.
33                // Validity is computed correctly from the parent array and indices.
34                unsafe { DecimalArray::new_unchecked(buffer, array.decimal_dtype(), validity) }
35            })
36        });
37
38        Ok(decimal.to_array())
39    }
40}
41
42register_kernel!(TakeKernelAdapter(DecimalVTable).lift());
43
44#[inline]
45fn take_to_buffer<I: IntegerPType, T: NativeDecimalType>(indices: &[I], values: &[T]) -> Buffer<T> {
46    indices.iter().map(|idx| values[idx.as_()]).collect()
47}
48
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;
    use vortex_scalar::DecimalValue;
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::arrays::DecimalArray;
    use crate::arrays::DecimalVTable;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Taking from a non-nullable `i128`-backed array returns the selected values and keeps the
    /// decimal dtype.
    #[test]
    fn test_take() {
        let source = DecimalArray::new(
            buffer![10i128, 11i128, 12i128, 13i128],
            DecimalDType::new(19, 1),
            Validity::NonNullable,
        );
        let positions = buffer![0, 2, 3].into_array();

        let output = take(source.as_ref(), positions.as_ref()).unwrap();
        let decimals = output.as_::<DecimalVTable>();

        assert_eq!(decimals.buffer::<i128>(), buffer![10i128, 12i128, 13i128]);
        assert_eq!(decimals.decimal_dtype(), DecimalDType::new(19, 1));
    }

    /// A null index produces a null output slot. Valid indices produce nullable decimal scalars
    /// holding the addressed values.
    #[test]
    fn test_take_null_indices() {
        let source = DecimalArray::new(
            buffer![i128::MAX, 11i128, 12i128, 13i128],
            DecimalDType::new(19, 1),
            Validity::NonNullable,
        );
        let positions = PrimitiveArray::from_option_iter([None, Some(2), Some(3)]).into_array();
        let output = take(source.as_ref(), positions.as_ref()).unwrap();

        assert!(output.scalar_at(0).is_null());

        // (output slot, expected unscaled value) for the non-null indices.
        for (slot, unscaled) in [(1, 12i128), (2, 13i128)] {
            let expected = Scalar::decimal(
                DecimalValue::I128(unscaled),
                source.decimal_dtype(),
                Nullability::Nullable,
            );
            assert_eq!(output.scalar_at(slot), expected);
        }
    }

    /// Runs the shared take conformance suite over several storage widths, nullabilities and
    /// sizes.
    #[rstest]
    #[case(DecimalArray::new(
        buffer![100i128, 200i128, 300i128, 400i128, 500i128],
        DecimalDType::new(19, 2),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![10i64, 20i64, 30i64, 40i64, 50i64],
        DecimalDType::new(10, 1),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![1i32, 2i32, 3i32, 4i32, 5i32],
        DecimalDType::new(5, 0),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![1000i128, 2000i128, 3000i128, 4000i128, 5000i128],
        DecimalDType::new(19, 3),
        Validity::from_iter([true, false, true, true, false]),
    ))]
    #[case(DecimalArray::new(
        buffer![42i128],
        DecimalDType::new(19, 0),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        Buffer::from_iter((0i128..100).map(|step| step * 1000)),
        DecimalDType::new(19, 4),
        Validity::NonNullable,
    ))]
    fn test_take_decimal_conformance(#[case] decimal: DecimalArray) {
        test_take_conformance(decimal.as_ref());
    }
}