vortex_array/arrays/decimal/
ops.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_dtype::DecimalDType;
6use vortex_error::VortexResult;
7use vortex_scalar::{DecimalValue, NativeDecimalType, Scalar, match_each_decimal_value_type};
8
9use crate::arrays::{DecimalArray, DecimalVTable};
10use crate::validity::Validity;
11use crate::vtable::OperationsVTable;
12use crate::{ArrayRef, IntoArray};
13
14impl OperationsVTable<DecimalVTable> for DecimalVTable {
15    fn slice(array: &DecimalArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
16        match_each_decimal_value_type!(array.values_type(), |D| {
17            slice_typed(
18                array.buffer::<D>(),
19                start,
20                stop,
21                array.decimal_dtype(),
22                array.validity.clone(),
23            )
24        })
25    }
26
27    fn scalar_at(array: &DecimalArray, index: usize) -> VortexResult<Scalar> {
28        let scalar = match_each_decimal_value_type!(array.values_type(), |D| {
29            Scalar::decimal(
30                DecimalValue::from(array.buffer::<D>()[index]),
31                array.decimal_dtype(),
32                array.dtype().nullability(),
33            )
34        });
35        Ok(scalar)
36    }
37}
38
39fn slice_typed<T: NativeDecimalType>(
40    values: Buffer<T>,
41    start: usize,
42    end: usize,
43    decimal_dtype: DecimalDType,
44    validity: Validity,
45) -> VortexResult<ArrayRef> {
46    let sliced = values.slice(start..end);
47    let validity = validity.slice(start, end)?;
48    Ok(DecimalArray::new(sliced, decimal_dtype, validity).into_array())
49}
50
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::Array;
    use crate::arrays::{DecimalArray, DecimalVTable};
    use crate::validity::Validity;

    /// Slicing a non-nullable array keeps exactly the requested window of values.
    #[test]
    fn test_slice() {
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 4000i128],
            DecimalDType::new(3, 2),
            Validity::NonNullable,
        )
        .to_array();

        let sliced = array.slice(1, 3).unwrap();
        assert_eq!(sliced.len(), 2);

        let decimal = sliced.as_::<DecimalVTable>();
        assert_eq!(decimal.buffer::<i128>(), buffer![200i128, 300i128]);
    }

    /// Slicing a nullable array must carry the validity of the window along with the values.
    #[test]
    fn test_slice_nullable() {
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 4000i128],
            DecimalDType::new(3, 2),
            Validity::from_iter([false, true, false, true]),
        )
        .to_array();

        let sliced = array.slice(1, 3).unwrap();
        assert_eq!(sliced.len(), 2);

        // Source positions 1 and 2 are valid and null respectively; the slice must agree.
        assert!(sliced.is_valid(0).unwrap());
        assert!(!sliced.is_valid(1).unwrap());

        let decimal = sliced.as_::<DecimalVTable>();
        assert_eq!(decimal.buffer::<i128>(), buffer![200i128, 300i128]);
    }

    /// `scalar_at` returns the stored value with the array's decimal dtype and nullability.
    #[test]
    fn test_scalar_at() {
        let array = DecimalArray::new(
            buffer![100i128],
            DecimalDType::new(3, 2),
            Validity::NonNullable,
        );

        assert_eq!(
            array.scalar_at(0).unwrap(),
            Scalar::decimal(
                DecimalValue::I128(100),
                DecimalDType::new(3, 2),
                Nullability::NonNullable
            )
        );
    }
}