// vortex_array/arrays/decimal/compute/take.rs
use num_traits::AsPrimitive;
use vortex_buffer::Buffer;
use vortex_dtype::{NativePType, match_each_integer_ptype};
use vortex_error::VortexResult;
use vortex_scalar::{NativeDecimalType, match_each_decimal_value_type};

use crate::arrays::{DecimalArray, DecimalVTable};
use crate::compute::{TakeKernel, TakeKernelAdapter};
use crate::vtable::ValidityHelper;
use crate::{Array, ArrayRef, ToCanonical, register_kernel};

15impl TakeKernel for DecimalVTable {
16 fn take(&self, array: &DecimalArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
17 let indices = indices.to_primitive();
18 let validity = array.validity().take(indices.as_ref())?;
19
20 let decimal = match_each_decimal_value_type!(array.values_type(), |D| {
23 match_each_integer_ptype!(indices.ptype(), |I| {
24 let buffer =
25 take_to_buffer::<I, D>(indices.as_slice::<I>(), array.buffer::<D>().as_slice());
26 unsafe { DecimalArray::new_unchecked(buffer, array.decimal_dtype(), validity) }
29 })
30 });
31
32 Ok(decimal.to_array())
33 }
34}
35
36register_kernel!(TakeKernelAdapter(DecimalVTable).lift());
37
38#[inline]
39fn take_to_buffer<I: NativePType + AsPrimitive<usize>, T: NativeDecimalType>(
40 indices: &[I],
41 values: &[T],
42) -> Buffer<T> {
43 indices.iter().map(|idx| values[idx.as_()]).collect()
44}
45
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::IntoArray;
    use crate::arrays::{DecimalArray, DecimalVTable, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Taking positions 0, 2 and 3 from a non-nullable i128-backed array yields exactly those
    /// values, in index order, and keeps the decimal dtype.
    #[test]
    fn test_take() {
        let source = DecimalArray::new(
            buffer![10i128, 11i128, 12i128, 13i128],
            DecimalDType::new(19, 1),
            Validity::NonNullable,
        );
        let positions = buffer![0, 2, 3].into_array();

        let output = take(source.as_ref(), positions.as_ref()).unwrap();
        let output_decimals = output.as_::<DecimalVTable>();

        assert_eq!(
            output_decimals.buffer::<i128>(),
            buffer![10i128, 12i128, 13i128]
        );
        assert_eq!(output_decimals.decimal_dtype(), DecimalDType::new(19, 1));
    }

    /// A null index produces a null output slot; the remaining slots hold the selected values
    /// as nullable decimal scalars.
    #[test]
    fn test_take_null_indices() {
        let source = DecimalArray::new(
            buffer![i128::MAX, 11i128, 12i128, 13i128],
            DecimalDType::new(19, 1),
            Validity::NonNullable,
        );
        let positions = PrimitiveArray::from_option_iter([None, Some(2), Some(3)]).into_array();

        let output = take(source.as_ref(), positions.as_ref()).unwrap();

        // The first index is null, so the first output slot must be null.
        assert!(output.scalar_at(0).is_null());

        for (slot, expected) in [(1, 12i128), (2, 13i128)] {
            assert_eq!(
                output.scalar_at(slot),
                Scalar::decimal(
                    DecimalValue::I128(expected),
                    source.decimal_dtype(),
                    Nullability::Nullable
                )
            );
        }
    }

    /// Runs the shared take conformance suite over decimal arrays backed by i32, i64 and i128
    /// storage, with and without nulls, including a single-element and a 100-element input.
    #[rstest]
    #[case(DecimalArray::new(
        buffer![100i128, 200i128, 300i128, 400i128, 500i128],
        DecimalDType::new(19, 2),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![10i64, 20i64, 30i64, 40i64, 50i64],
        DecimalDType::new(10, 1),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![1i32, 2i32, 3i32, 4i32, 5i32],
        DecimalDType::new(5, 0),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![1000i128, 2000i128, 3000i128, 4000i128, 5000i128],
        DecimalDType::new(19, 3),
        Validity::from_iter([true, false, true, true, false]),
    ))]
    #[case(DecimalArray::new(
        buffer![42i128],
        DecimalDType::new(19, 0),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        Buffer::from_iter((0..100i128).map(|i| i * 1000)),
        DecimalDType::new(19, 4),
        Validity::NonNullable,
    ))]
    fn test_take_decimal_conformance(#[case] input: DecimalArray) {
        test_take_conformance(input.as_ref());
    }
}