vortex_array/arrays/decimal/
ops.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Range;
5
6use vortex_buffer::Buffer;
7use vortex_dtype::DecimalDType;
8use vortex_scalar::{DecimalValue, NativeDecimalType, Scalar, match_each_decimal_value_type};
9
10use crate::arrays::{DecimalArray, DecimalVTable};
11use crate::validity::Validity;
12use crate::vtable::OperationsVTable;
13use crate::{ArrayRef, IntoArray};
14
15impl OperationsVTable<DecimalVTable> for DecimalVTable {
16    fn slice(array: &DecimalArray, range: Range<usize>) -> ArrayRef {
17        match_each_decimal_value_type!(array.values_type(), |D| {
18            slice_typed(
19                array.buffer::<D>(),
20                range,
21                array.decimal_dtype(),
22                array.validity.clone(),
23            )
24        })
25    }
26
27    fn scalar_at(array: &DecimalArray, index: usize) -> Scalar {
28        match_each_decimal_value_type!(array.values_type(), |D| {
29            Scalar::decimal(
30                DecimalValue::from(array.buffer::<D>()[index]),
31                array.decimal_dtype(),
32                array.dtype().nullability(),
33            )
34        })
35    }
36}
37
38fn slice_typed<T: NativeDecimalType>(
39    values: Buffer<T>,
40    range: Range<usize>,
41    decimal_dtype: DecimalDType,
42    validity: Validity,
43) -> ArrayRef {
44    let sliced = values.slice(range.clone());
45    let validity = validity.slice(range);
46    // SAFETY: Slicing preserves all DecimalArray invariants:
47    // - Buffer is correctly typed and sized from the slice operation.
48    // - Decimal dtype is preserved from the parent array.
49    // - Validity is correctly sliced to match the new length.
50    unsafe { DecimalArray::new_unchecked(sliced, decimal_dtype, validity) }.into_array()
51}
52
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::Array;
    use crate::arrays::{DecimalArray, DecimalVTable};
    use crate::validity::Validity;

    #[test]
    fn test_slice() {
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 4000i128],
            DecimalDType::new(3, 2),
            Validity::NonNullable,
        )
        .to_array();

        let sliced = array.slice(1..3);
        assert_eq!(sliced.len(), 2);

        let decimal = sliced.as_::<DecimalVTable>();
        assert_eq!(decimal.buffer::<i128>(), buffer![200i128, 300i128]);
    }

    #[test]
    fn test_slice_nullable() {
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 4000i128],
            DecimalDType::new(3, 2),
            Validity::from_iter([false, true, false, true]),
        )
        .to_array();

        let sliced = array.slice(1..3);
        assert_eq!(sliced.len(), 2);

        // The values buffer must be sliced with the same range as the validity.
        let decimal = sliced.as_::<DecimalVTable>();
        assert_eq!(decimal.buffer::<i128>(), buffer![200i128, 300i128]);

        // Original index 1 is valid and original index 2 is null; the slice
        // must keep values and validity aligned.
        assert_eq!(
            sliced.scalar_at(0),
            Scalar::decimal(
                DecimalValue::I128(200),
                DecimalDType::new(3, 2),
                Nullability::Nullable
            )
        );
        assert!(sliced.scalar_at(1).is_null());
    }

    #[test]
    fn test_scalar_at() {
        let array = DecimalArray::new(
            buffer![100i128],
            DecimalDType::new(3, 2),
            Validity::NonNullable,
        );

        assert_eq!(
            array.scalar_at(0),
            Scalar::decimal(
                DecimalValue::I128(100),
                DecimalDType::new(3, 2),
                Nullability::NonNullable
            )
        );
    }

    #[test]
    fn test_scalar_at_nullable() {
        let array = DecimalArray::new(
            buffer![100i128, 200i128],
            DecimalDType::new(3, 2),
            Validity::from_iter([true, false]),
        );

        // A valid slot yields a nullable decimal scalar with the stored value.
        assert_eq!(
            array.scalar_at(0),
            Scalar::decimal(
                DecimalValue::I128(100),
                DecimalDType::new(3, 2),
                Nullability::Nullable
            )
        );
        // An invalid slot yields a null scalar.
        assert!(array.scalar_at(1).is_null());
    }
}