vortex_array/arrays/decimal/compute/take.rs

1use num_traits::AsPrimitive;
2use vortex_buffer::Buffer;
3use vortex_dtype::{NativePType, match_each_integer_ptype};
4use vortex_error::{VortexResult, vortex_err};
5
6use crate::arrays::{DecimalArray, DecimalEncoding, NativeDecimalType, PrimitiveArray};
7use crate::compute::TakeFn;
8use crate::variants::PrimitiveArrayTrait;
9use crate::{Array, ArrayRef, match_each_decimal_value_type};
10
11impl TakeFn<&DecimalArray> for DecimalEncoding {
12    fn take(&self, array: &DecimalArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
13        let indices = indices
14            .as_any()
15            .downcast_ref::<PrimitiveArray>()
16            .ok_or_else(|| vortex_err!("indices must be a PrimitiveArray"))?;
17
18        let decimal = match_each_decimal_value_type!(array.values_type(), |$D| {
19                match_each_integer_ptype!(indices.ptype(), |$I| {
20                    let buffer = take_to_buffer::<$I, $D>(indices.as_slice::<$I>(), array.buffer::<$D>().as_slice());
21                    DecimalArray::new(buffer, array.decimal_dtype(), array.validity().clone())
22                })
23        });
24
25        Ok(decimal.to_array())
26    }
27}
28
29#[inline]
30fn take_to_buffer<I: NativePType + AsPrimitive<usize>, T: NativeDecimalType>(
31    indices: &[I],
32    values: &[T],
33) -> Buffer<T> {
34    indices.iter().map(|idx| values[idx.as_()]).collect()
35}
36
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DecimalDType;

    use crate::arrays::DecimalArray;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray};

    /// Taking positions 0, 2 and 3 from a non-nullable `i128` decimal array yields those values
    /// in index order and keeps the source precision/scale.
    #[test]
    fn test_take() {
        let source_values = buffer![10i128, 11i128, 12i128, 13i128];
        let source = DecimalArray::new(
            source_values,
            DecimalDType::new(19, 1),
            Validity::NonNullable,
        );

        let positions = buffer![0, 2, 3].into_array();
        let gathered = take(&source, positions.as_ref()).unwrap();
        let gathered_decimals = gathered.as_any().downcast_ref::<DecimalArray>().unwrap();

        let expected_values = buffer![10i128, 12i128, 13i128];
        assert_eq!(gathered_decimals.buffer::<i128>(), expected_values);
        assert_eq!(gathered_decimals.decimal_dtype(), DecimalDType::new(19, 1));
    }
}