vortex_array/arrays/decimal/compute/
take.rs1use vortex_buffer::Buffer;
5use vortex_dtype::{
6 IntegerPType, NativeDecimalType, match_each_decimal_value_type, match_each_integer_ptype,
7};
8use vortex_error::VortexResult;
9
10use crate::arrays::{DecimalArray, DecimalVTable};
11use crate::compute::{TakeKernel, TakeKernelAdapter};
12use crate::vtable::ValidityHelper;
13use crate::{Array, ArrayRef, ToCanonical, register_kernel};
14
15impl TakeKernel for DecimalVTable {
16 fn take(&self, array: &DecimalArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
17 let indices = indices.to_primitive();
18 let validity = array.validity().take(indices.as_ref())?;
19
20 let decimal = match_each_decimal_value_type!(array.values_type(), |D| {
23 match_each_integer_ptype!(indices.ptype(), |I| {
24 let buffer =
25 take_to_buffer::<I, D>(indices.as_slice::<I>(), array.buffer::<D>().as_slice());
26 unsafe { DecimalArray::new_unchecked(buffer, array.decimal_dtype(), validity) }
29 })
30 });
31
32 Ok(decimal.to_array())
33 }
34}
35
// Registers `DecimalVTable`'s `TakeKernel` implementation, wrapped in `TakeKernelAdapter` and
// lifted. Presumably this lets `crate::compute::take` dispatch decimal arrays to
// `DecimalVTable::take`; NOTE(review): confirm the registry semantics of `register_kernel!`.
register_kernel!(TakeKernelAdapter(DecimalVTable).lift());
37
38#[inline]
39fn take_to_buffer<I: IntegerPType, T: NativeDecimalType>(indices: &[I], values: &[T]) -> Buffer<T> {
40 indices.iter().map(|idx| values[idx.as_()]).collect()
41}
42
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::IntoArray;
    use crate::arrays::{DecimalArray, DecimalVTable, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Basic gather from `i128` storage keeps the values and the decimal dtype.
    #[test]
    fn test_take() {
        let array = DecimalArray::new(
            buffer![10i128, 11i128, 12i128, 13i128],
            DecimalDType::new(19, 1),
            Validity::NonNullable,
        );

        let indices = buffer![0, 2, 3].into_array();
        let taken = take(array.as_ref(), indices.as_ref()).unwrap();
        let taken_decimals = taken.as_::<DecimalVTable>();
        assert_eq!(
            taken_decimals.buffer::<i128>(),
            buffer![10i128, 12i128, 13i128]
        );
        assert_eq!(taken_decimals.decimal_dtype(), DecimalDType::new(19, 1));
    }

    /// Null indices produce null outputs; valid indices produce nullable scalars.
    #[test]
    fn test_take_null_indices() {
        let array = DecimalArray::new(
            buffer![i128::MAX, 11i128, 12i128, 13i128],
            DecimalDType::new(19, 1),
            Validity::NonNullable,
        );

        let indices = PrimitiveArray::from_option_iter([None, Some(2), Some(3)]).into_array();
        let taken = take(array.as_ref(), indices.as_ref()).unwrap();

        assert!(taken.scalar_at(0).is_null());
        assert_eq!(
            taken.scalar_at(1),
            Scalar::decimal(
                DecimalValue::I128(12i128),
                array.decimal_dtype(),
                Nullability::Nullable
            )
        );

        assert_eq!(
            taken.scalar_at(2),
            Scalar::decimal(
                DecimalValue::I128(13i128),
                array.decimal_dtype(),
                Nullability::Nullable
            )
        );
    }

    /// Unsigned, repeated and out-of-order indices gather the expected values from a
    /// non-`i128` storage type.
    #[test]
    fn test_take_repeated_unsigned_indices() {
        let array = DecimalArray::new(
            buffer![10i64, 20i64, 30i64],
            DecimalDType::new(10, 1),
            Validity::NonNullable,
        );

        let indices = buffer![2u8, 0u8, 2u8, 1u8].into_array();
        let taken = take(array.as_ref(), indices.as_ref()).unwrap();
        let taken_decimals = taken.as_::<DecimalVTable>();
        assert_eq!(
            taken_decimals.buffer::<i64>(),
            buffer![30i64, 10i64, 30i64, 20i64]
        );
        assert_eq!(taken_decimals.decimal_dtype(), DecimalDType::new(10, 1));
    }

    /// Nulls in the source array follow their values through the take.
    #[test]
    fn test_take_preserves_source_nulls() {
        let array = DecimalArray::new(
            buffer![1i32, 2i32, 3i32],
            DecimalDType::new(5, 0),
            Validity::from_iter([true, false, true]),
        );

        let indices = buffer![1u32, 2u32, 1u32].into_array();
        let taken = take(array.as_ref(), indices.as_ref()).unwrap();

        assert!(taken.scalar_at(0).is_null());
        assert!(!taken.scalar_at(1).is_null());
        assert!(taken.scalar_at(2).is_null());
    }

    #[rstest]
    #[case(DecimalArray::new(
        buffer![100i128, 200i128, 300i128, 400i128, 500i128],
        DecimalDType::new(19, 2),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![10i64, 20i64, 30i64, 40i64, 50i64],
        DecimalDType::new(10, 1),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![1i32, 2i32, 3i32, 4i32, 5i32],
        DecimalDType::new(5, 0),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![1000i128, 2000i128, 3000i128, 4000i128, 5000i128],
        DecimalDType::new(19, 3),
        Validity::from_iter([true, false, true, true, false]),
    ))]
    #[case(DecimalArray::new(
        buffer![42i128],
        DecimalDType::new(19, 0),
        Validity::NonNullable,
    ))]
    #[case({
        let values: Vec<i128> = (0..100).map(|i| i * 1000).collect();
        DecimalArray::new(
            Buffer::from_iter(values),
            DecimalDType::new(19, 4),
            Validity::NonNullable,
        )
    })]
    fn test_take_decimal_conformance(#[case] array: DecimalArray) {
        test_take_conformance(array.as_ref());
    }
}