vortex_array/arrays/decimal/compute/
take.rs1use vortex_buffer::Buffer;
5use vortex_dtype::{IntegerPType, match_each_integer_ptype};
6use vortex_error::VortexResult;
7use vortex_scalar::{NativeDecimalType, match_each_decimal_value_type};
8
9use crate::arrays::{DecimalArray, DecimalVTable};
10use crate::compute::{TakeKernel, TakeKernelAdapter};
11use crate::vtable::ValidityHelper;
12use crate::{Array, ArrayRef, ToCanonical, register_kernel};
13
14impl TakeKernel for DecimalVTable {
15 fn take(&self, array: &DecimalArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
16 let indices = indices.to_primitive();
17 let validity = array.validity().take(indices.as_ref())?;
18
19 let decimal = match_each_decimal_value_type!(array.values_type(), |D| {
22 match_each_integer_ptype!(indices.ptype(), |I| {
23 let buffer =
24 take_to_buffer::<I, D>(indices.as_slice::<I>(), array.buffer::<D>().as_slice());
25 unsafe { DecimalArray::new_unchecked(buffer, array.decimal_dtype(), validity) }
28 })
29 });
30
31 Ok(decimal.to_array())
32 }
33}
34
// Registers the decimal `TakeKernel` implementation, wrapped in `TakeKernelAdapter`, so the
// generic `take` compute function can presumably dispatch to it for `DecimalArray` inputs.
// NOTE(review): the exact meaning of `.lift()` is defined in `crate::compute` — confirm there.
register_kernel!(TakeKernelAdapter(DecimalVTable).lift());
36
37#[inline]
38fn take_to_buffer<I: IntegerPType, T: NativeDecimalType>(indices: &[I], values: &[T]) -> Buffer<T> {
39 indices.iter().map(|idx| values[idx.as_()]).collect()
40}
41
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::IntoArray;
    use crate::arrays::{DecimalArray, DecimalVTable, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Taking a subset of a non-nullable `i128` decimal array keeps the selected values in
    /// index order and preserves precision and scale.
    #[test]
    fn test_take() {
        let source = DecimalArray::new(
            buffer![10i128, 11i128, 12i128, 13i128],
            DecimalDType::new(19, 1),
            Validity::NonNullable,
        );
        let picks = buffer![0, 2, 3].into_array();

        let out = take(source.as_ref(), picks.as_ref()).unwrap();
        let decimals = out.as_::<DecimalVTable>();

        let expected_values = buffer![10i128, 12i128, 13i128];
        assert_eq!(decimals.buffer::<i128>(), expected_values);
        assert_eq!(decimals.decimal_dtype(), DecimalDType::new(19, 1));
    }

    /// A null index yields a null output element, and the remaining positions come back as
    /// nullable decimal scalars with the source dtype.
    #[test]
    fn test_take_null_indices() {
        let source = DecimalArray::new(
            buffer![i128::MAX, 11i128, 12i128, 13i128],
            DecimalDType::new(19, 1),
            Validity::NonNullable,
        );
        let picks = PrimitiveArray::from_option_iter([None, Some(2), Some(3)]).into_array();
        let out = take(source.as_ref(), picks.as_ref()).unwrap();

        let nullable_decimal = |value: i128| {
            Scalar::decimal(
                DecimalValue::I128(value),
                source.decimal_dtype(),
                Nullability::Nullable,
            )
        };

        assert!(out.scalar_at(0).is_null());
        for (position, expected) in [(1, 12i128), (2, 13i128)] {
            assert_eq!(out.scalar_at(position), nullable_decimal(expected));
        }
    }

    /// Runs the shared take conformance suite across several decimal widths, a nullable
    /// input, a single-element input and a larger input.
    #[rstest]
    #[case(DecimalArray::new(
        buffer![100i128, 200i128, 300i128, 400i128, 500i128],
        DecimalDType::new(19, 2),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![10i64, 20i64, 30i64, 40i64, 50i64],
        DecimalDType::new(10, 1),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![1i32, 2i32, 3i32, 4i32, 5i32],
        DecimalDType::new(5, 0),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        buffer![1000i128, 2000i128, 3000i128, 4000i128, 5000i128],
        DecimalDType::new(19, 3),
        Validity::from_iter([true, false, true, true, false]),
    ))]
    #[case(DecimalArray::new(
        buffer![42i128],
        DecimalDType::new(19, 0),
        Validity::NonNullable,
    ))]
    #[case(DecimalArray::new(
        Buffer::from_iter((0..100).map(|i: i128| i * 1000)),
        DecimalDType::new(19, 4),
        Validity::NonNullable,
    ))]
    fn test_take_decimal_conformance(#[case] array: DecimalArray) {
        test_take_conformance(array.as_ref());
    }
}