vortex_array/arrays/decimal/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_buffer::BooleanBufferBuilder;
5use itertools::Itertools;
6use vortex_buffer::{Buffer, BufferMut, ByteBuffer};
7use vortex_dtype::{DType, DecimalDType, IntegerPType, match_each_integer_ptype};
8use vortex_error::{VortexExpect, VortexResult, vortex_ensure, vortex_panic};
9use vortex_scalar::{BigCast, DecimalValueType, NativeDecimalType, match_each_decimal_value_type};
10
11use crate::ToCanonical;
12use crate::arrays::is_compatible_decimal_value_type;
13use crate::patches::Patches;
14use crate::stats::ArrayStats;
15use crate::validity::Validity;
16use crate::vtable::ValidityHelper;
17
18/// A decimal array that stores fixed-precision decimal numbers with configurable scale.
19///
20/// This mirrors the Apache Arrow Decimal encoding and provides exact arithmetic for
21/// financial and scientific computations where floating-point precision loss is unacceptable.
22///
23/// ## Storage Format
24///
25/// Decimals are stored as scaled integers in a supported scalar value type.
26///
27/// The precisions supported for each scalar type are:
28/// - **i8**: precision 1-2 digits
29/// - **i16**: precision 3-4 digits
30/// - **i32**: precision 5-9 digits
31/// - **i64**: precision 10-18 digits
32/// - **i128**: precision 19-38 digits
33/// - **i256**: precision 39-76 digits
34///
35/// These are just the maximal ranges for each scalar type, but it is perfectly legal to store
36/// values with precision that does not match this exactly. For example, a valid DecimalArray with
37/// precision=39 may store its values in an `i8` if all of the actual values fit into it.
38///
39/// Similarly, a `DecimalArray` can be built that stores a set of precision=2 values in a
40/// `Buffer<i256>`.
41///
42/// ## Precision and Scale
43///
44/// - **Precision**: Total number of significant digits (1-76, u8 range)
45/// - **Scale**: Number of digits after the decimal point (-128 to 127, i8 range)
46/// - **Value**: `stored_integer / 10^scale`
47///
48/// For example, with precision=5 and scale=2:
49/// - Stored value 12345 represents 123.45
50/// - Range: -999.99 to 999.99
51///
52/// ## Valid Scalar Types
53///
54/// The underlying storage uses these native types based on precision:
55/// - `DecimalValueType::I8`, `I16`, `I32`, `I64`, `I128`, `I256`
56/// - Type selection is automatic based on the required precision
57///
58/// # Examples
59///
60/// ```
61/// use vortex_array::arrays::DecimalArray;
62/// use vortex_dtype::DecimalDType;
63/// use vortex_buffer::{buffer, Buffer};
64/// use vortex_array::validity::Validity;
65///
66/// // Create a decimal array with precision=5, scale=2 (e.g., 123.45)
67/// let decimal_dtype = DecimalDType::new(5, 2);
68/// let values = buffer![12345i32, 67890i32, -12300i32]; // 123.45, 678.90, -123.00
69/// let array = DecimalArray::new(values, decimal_dtype, Validity::NonNullable);
70///
71/// assert_eq!(array.precision(), 5);
72/// assert_eq!(array.scale(), 2);
73/// assert_eq!(array.len(), 3);
74/// ```
75#[derive(Clone, Debug)]
76pub struct DecimalArray {
77    pub(super) dtype: DType,
78    pub(super) values: ByteBuffer,
79    pub(super) values_type: DecimalValueType,
80    pub(super) validity: Validity,
81    pub(super) stats_set: ArrayStats,
82}
83
84impl DecimalArray {
85    /// Creates a new [`DecimalArray`].
86    ///
87    /// # Panics
88    ///
89    /// Panics if the provided components do not satisfy the invariants documented in
90    /// [`DecimalArray::new_unchecked`].
91    pub fn new<T: NativeDecimalType>(
92        buffer: Buffer<T>,
93        decimal_dtype: DecimalDType,
94        validity: Validity,
95    ) -> Self {
96        Self::try_new(buffer, decimal_dtype, validity)
97            .vortex_expect("DecimalArray construction failed")
98    }
99
100    /// Constructs a new `DecimalArray`.
101    ///
102    /// See [`DecimalArray::new_unchecked`] for more information.
103    ///
104    /// # Errors
105    ///
106    /// Returns an error if the provided components do not satisfy the invariants documented in
107    /// [`DecimalArray::new_unchecked`].
108    pub fn try_new<T: NativeDecimalType>(
109        buffer: Buffer<T>,
110        decimal_dtype: DecimalDType,
111        validity: Validity,
112    ) -> VortexResult<Self> {
113        Self::validate(&buffer, &validity)?;
114
115        // SAFETY: validate ensures all invariants are met.
116        Ok(unsafe { Self::new_unchecked(buffer, decimal_dtype, validity) })
117    }
118
119    /// Creates a new [`DecimalArray`] without validation from these components:
120    ///
121    /// * `buffer` is a typed buffer containing the decimal values.
122    /// * `decimal_dtype` specifies the decimal precision and scale.
123    /// * `validity` holds the null values.
124    ///
125    /// # Safety
126    ///
127    /// The caller must ensure all of the following invariants are satisfied:
128    ///
129    /// - All non-null values in `buffer` must be representable within the specified precision.
130    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
131    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
132    pub unsafe fn new_unchecked<T: NativeDecimalType>(
133        buffer: Buffer<T>,
134        decimal_dtype: DecimalDType,
135        validity: Validity,
136    ) -> Self {
137        #[cfg(debug_assertions)]
138        Self::validate(&buffer, &validity)
139            .vortex_expect("[Debug Assertion]: Invalid `DecimalArray` parameters");
140
141        Self {
142            values: buffer.into_byte_buffer(),
143            values_type: T::VALUES_TYPE,
144            dtype: DType::Decimal(decimal_dtype, validity.nullability()),
145            validity,
146            stats_set: Default::default(),
147        }
148    }
149
150    /// Validates the components that would be used to create a [`DecimalArray`].
151    ///
152    /// This function checks all the invariants required by [`DecimalArray::new_unchecked`].
153    pub fn validate<T: NativeDecimalType>(
154        buffer: &Buffer<T>,
155        validity: &Validity,
156    ) -> VortexResult<()> {
157        if let Some(len) = validity.maybe_len() {
158            vortex_ensure!(
159                buffer.len() == len,
160                "Buffer and validity length mismatch: buffer={}, validity={}",
161                buffer.len(),
162                len,
163            );
164        }
165
166        Ok(())
167    }
168
169    /// Returns the underlying [`ByteBuffer`] of the array.
170    pub fn byte_buffer(&self) -> ByteBuffer {
171        self.values.clone()
172    }
173
174    pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
175        if self.values_type != T::VALUES_TYPE {
176            vortex_panic!(
177                "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
178                T::VALUES_TYPE,
179                self.values_type,
180            );
181        }
182        Buffer::<T>::from_byte_buffer(self.values.clone())
183    }
184
185    /// Returns the decimal type information
186    pub fn decimal_dtype(&self) -> DecimalDType {
187        if let DType::Decimal(decimal_dtype, _) = self.dtype {
188            decimal_dtype
189        } else {
190            vortex_panic!("Expected Decimal dtype, got {:?}", self.dtype)
191        }
192    }
193
194    pub fn values_type(&self) -> DecimalValueType {
195        self.values_type
196    }
197
198    pub fn precision(&self) -> u8 {
199        self.decimal_dtype().precision()
200    }
201
202    pub fn scale(&self) -> i8 {
203        self.decimal_dtype().scale()
204    }
205
206    pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
207        iter: I,
208        decimal_dtype: DecimalDType,
209    ) -> Self {
210        let iter = iter.into_iter();
211        let mut values = BufferMut::with_capacity(iter.size_hint().0);
212        let mut validity = BooleanBufferBuilder::new(values.capacity());
213
214        for i in iter {
215            match i {
216                None => {
217                    validity.append(false);
218                    values.push(T::default());
219                }
220                Some(e) => {
221                    validity.append(true);
222                    values.push(e);
223                }
224            }
225        }
226        Self::new(
227            values.freeze(),
228            decimal_dtype,
229            Validity::from(validity.finish()),
230        )
231    }
232
233    #[allow(clippy::cognitive_complexity)]
234    pub fn patch(self, patches: &Patches) -> Self {
235        let offset = patches.offset();
236        let patch_indices = patches.indices().to_primitive();
237        let patch_values = patches.values().to_decimal();
238
239        let patched_validity = self.validity().clone().patch(
240            self.len(),
241            offset,
242            patch_indices.as_ref(),
243            patch_values.validity(),
244        );
245        assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
246
247        match_each_integer_ptype!(patch_indices.ptype(), |I| {
248            let patch_indices = patch_indices.as_slice::<I>();
249            match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
250                let patch_values = patch_values.buffer::<PatchDVT>();
251                match_each_decimal_value_type!(self.values_type(), |ValuesDVT| {
252                    let buffer = self.buffer::<ValuesDVT>().into_mut();
253                    patch_typed(
254                        buffer,
255                        self.decimal_dtype(),
256                        patch_indices,
257                        offset,
258                        patch_values,
259                        patched_validity,
260                    )
261                })
262            })
263        })
264    }
265}
266
267fn patch_typed<I, ValuesDVT, PatchDVT>(
268    mut buffer: BufferMut<ValuesDVT>,
269    decimal_dtype: DecimalDType,
270    patch_indices: &[I],
271    patch_indices_offset: usize,
272    patch_values: Buffer<PatchDVT>,
273    patched_validity: Validity,
274) -> DecimalArray
275where
276    I: IntegerPType,
277    PatchDVT: NativeDecimalType,
278    ValuesDVT: NativeDecimalType,
279{
280    if !is_compatible_decimal_value_type(ValuesDVT::VALUES_TYPE, decimal_dtype) {
281        vortex_panic!(
282            "patch_typed: {:?} cannot represent every value in {}.",
283            ValuesDVT::VALUES_TYPE,
284            decimal_dtype
285        )
286    }
287
288    for (idx, value) in patch_indices.iter().zip_eq(patch_values.into_iter()) {
289        buffer[idx.as_() - patch_indices_offset] = <ValuesDVT as BigCast>::from(value).vortex_expect(
290            "values of a given DecimalDType are representable in all compatible NativeDecimalType",
291        );
292    }
293
294    DecimalArray::new(buffer.freeze(), decimal_dtype, patched_validity)
295}