vortex_array/arrays/decimal/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use vortex_buffer::{BitBufferMut, Buffer, BufferMut, ByteBuffer};
6use vortex_dtype::{
7    BigCast, DType, DecimalDType, DecimalType, IntegerPType, NativeDecimalType,
8    match_each_decimal_value_type, match_each_integer_ptype,
9};
10use vortex_error::{VortexExpect, VortexResult, vortex_ensure, vortex_panic};
11
12use crate::ToCanonical;
13use crate::patches::Patches;
14use crate::stats::ArrayStats;
15use crate::validity::Validity;
16use crate::vtable::ValidityHelper;
17
18/// A decimal array that stores fixed-precision decimal numbers with configurable scale.
19///
20/// This mirrors the Apache Arrow Decimal encoding and provides exact arithmetic for
21/// financial and scientific computations where floating-point precision loss is unacceptable.
22///
23/// ## Storage Format
24///
25/// Decimals are stored as scaled integers in a supported scalar value type.
26///
27/// The precisions supported for each scalar type are:
28/// - **i8**: precision 1-2 digits
29/// - **i16**: precision 3-4 digits
30/// - **i32**: precision 5-9 digits
31/// - **i64**: precision 10-18 digits
32/// - **i128**: precision 19-38 digits
33/// - **i256**: precision 39-76 digits
34///
35/// These are just the maximal ranges for each scalar type, but it is perfectly legal to store
36/// values with precision that does not match this exactly. For example, a valid DecimalArray with
37/// precision=39 may store its values in an `i8` if all of the actual values fit into it.
38///
39/// Similarly, a `DecimalArray` can be built that stores a set of precision=2 values in a
40/// `Buffer<i256>`.
41///
42/// ## Precision and Scale
43///
44/// - **Precision**: Total number of significant digits (1-76, u8 range)
45/// - **Scale**: Number of digits after the decimal point (-128 to 127, i8 range)
46/// - **Value**: `stored_integer / 10^scale`
47///
48/// For example, with precision=5 and scale=2:
49/// - Stored value 12345 represents 123.45
50/// - Range: -999.99 to 999.99
51///
52/// ## Valid Scalar Types
53///
54/// The underlying storage uses these native types based on precision:
55/// - `DecimalType::I8`, `I16`, `I32`, `I64`, `I128`, `I256`
56/// - Type selection is automatic based on the required precision
57///
58/// # Examples
59///
60/// ```
61/// use vortex_array::arrays::DecimalArray;
62/// use vortex_dtype::DecimalDType;
63/// use vortex_buffer::{buffer, Buffer};
64/// use vortex_array::validity::Validity;
65///
66/// // Create a decimal array with precision=5, scale=2 (e.g., 123.45)
67/// let decimal_dtype = DecimalDType::new(5, 2);
68/// let values = buffer![12345i32, 67890i32, -12300i32]; // 123.45, 678.90, -123.00
69/// let array = DecimalArray::new(values, decimal_dtype, Validity::NonNullable);
70///
71/// assert_eq!(array.precision(), 5);
72/// assert_eq!(array.scale(), 2);
73/// assert_eq!(array.len(), 3);
74/// ```
75#[derive(Clone, Debug)]
76pub struct DecimalArray {
77    pub(super) dtype: DType,
78    pub(super) values: ByteBuffer,
79    pub(super) values_type: DecimalType,
80    pub(super) validity: Validity,
81    pub(super) stats_set: ArrayStats,
82}
83
84impl DecimalArray {
85    /// Creates a new [`DecimalArray`].
86    ///
87    /// # Panics
88    ///
89    /// Panics if the provided components do not satisfy the invariants documented in
90    /// [`DecimalArray::new_unchecked`].
91    pub fn new<T: NativeDecimalType>(
92        buffer: Buffer<T>,
93        decimal_dtype: DecimalDType,
94        validity: Validity,
95    ) -> Self {
96        Self::try_new(buffer, decimal_dtype, validity)
97            .vortex_expect("DecimalArray construction failed")
98    }
99
100    /// Constructs a new `DecimalArray`.
101    ///
102    /// See [`DecimalArray::new_unchecked`] for more information.
103    ///
104    /// # Errors
105    ///
106    /// Returns an error if the provided components do not satisfy the invariants documented in
107    /// [`DecimalArray::new_unchecked`].
108    pub fn try_new<T: NativeDecimalType>(
109        buffer: Buffer<T>,
110        decimal_dtype: DecimalDType,
111        validity: Validity,
112    ) -> VortexResult<Self> {
113        Self::validate(&buffer, &validity)?;
114
115        // SAFETY: validate ensures all invariants are met.
116        Ok(unsafe { Self::new_unchecked(buffer, decimal_dtype, validity) })
117    }
118
119    /// Creates a new [`DecimalArray`] without validation from these components:
120    ///
121    /// * `buffer` is a typed buffer containing the decimal values.
122    /// * `decimal_dtype` specifies the decimal precision and scale.
123    /// * `validity` holds the null values.
124    ///
125    /// # Safety
126    ///
127    /// The caller must ensure all of the following invariants are satisfied:
128    ///
129    /// - All non-null values in `buffer` must be representable within the specified precision.
130    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
131    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
132    pub unsafe fn new_unchecked<T: NativeDecimalType>(
133        buffer: Buffer<T>,
134        decimal_dtype: DecimalDType,
135        validity: Validity,
136    ) -> Self {
137        #[cfg(debug_assertions)]
138        Self::validate(&buffer, &validity)
139            .vortex_expect("[Debug Assertion]: Invalid `DecimalArray` parameters");
140
141        Self {
142            values: buffer.into_byte_buffer(),
143            values_type: T::DECIMAL_TYPE,
144            dtype: DType::Decimal(decimal_dtype, validity.nullability()),
145            validity,
146            stats_set: Default::default(),
147        }
148    }
149
150    /// Validates the components that would be used to create a [`DecimalArray`].
151    ///
152    /// This function checks all the invariants required by [`DecimalArray::new_unchecked`].
153    pub fn validate<T: NativeDecimalType>(
154        buffer: &Buffer<T>,
155        validity: &Validity,
156    ) -> VortexResult<()> {
157        if let Some(len) = validity.maybe_len() {
158            vortex_ensure!(
159                buffer.len() == len,
160                "Buffer and validity length mismatch: buffer={}, validity={}",
161                buffer.len(),
162                len,
163            );
164        }
165
166        Ok(())
167    }
168
169    /// Returns the underlying [`ByteBuffer`] of the array.
170    pub fn byte_buffer(&self) -> ByteBuffer {
171        self.values.clone()
172    }
173
174    pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
175        if self.values_type != T::DECIMAL_TYPE {
176            vortex_panic!(
177                "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
178                T::DECIMAL_TYPE,
179                self.values_type,
180            );
181        }
182        Buffer::<T>::from_byte_buffer(self.values.clone())
183    }
184
185    /// Returns the decimal type information
186    pub fn decimal_dtype(&self) -> DecimalDType {
187        if let DType::Decimal(decimal_dtype, _) = self.dtype {
188            decimal_dtype
189        } else {
190            vortex_panic!("Expected Decimal dtype, got {:?}", self.dtype)
191        }
192    }
193
194    pub fn values_type(&self) -> DecimalType {
195        self.values_type
196    }
197
198    pub fn precision(&self) -> u8 {
199        self.decimal_dtype().precision()
200    }
201
202    pub fn scale(&self) -> i8 {
203        self.decimal_dtype().scale()
204    }
205
206    pub fn from_iter<T: NativeDecimalType, I: IntoIterator<Item = T>>(
207        iter: I,
208        decimal_dtype: DecimalDType,
209    ) -> Self {
210        let iter = iter.into_iter();
211
212        Self::new(
213            BufferMut::from_iter(iter).freeze(),
214            decimal_dtype,
215            Validity::NonNullable,
216        )
217    }
218
219    pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
220        iter: I,
221        decimal_dtype: DecimalDType,
222    ) -> Self {
223        let iter = iter.into_iter();
224        let mut values = BufferMut::with_capacity(iter.size_hint().0);
225        let mut validity = BitBufferMut::with_capacity(values.capacity());
226
227        for i in iter {
228            match i {
229                None => {
230                    validity.append(false);
231                    values.push(T::default());
232                }
233                Some(e) => {
234                    validity.append(true);
235                    values.push(e);
236                }
237            }
238        }
239        Self::new(
240            values.freeze(),
241            decimal_dtype,
242            Validity::from(validity.freeze()),
243        )
244    }
245
246    #[allow(clippy::cognitive_complexity)]
247    pub fn patch(self, patches: &Patches) -> Self {
248        let offset = patches.offset();
249        let patch_indices = patches.indices().to_primitive();
250        let patch_values = patches.values().to_decimal();
251
252        let patched_validity = self.validity().clone().patch(
253            self.len(),
254            offset,
255            patch_indices.as_ref(),
256            patch_values.validity(),
257        );
258        assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
259
260        match_each_integer_ptype!(patch_indices.ptype(), |I| {
261            let patch_indices = patch_indices.as_slice::<I>();
262            match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
263                let patch_values = patch_values.buffer::<PatchDVT>();
264                match_each_decimal_value_type!(self.values_type(), |ValuesDVT| {
265                    let buffer = self.buffer::<ValuesDVT>().into_mut();
266                    patch_typed(
267                        buffer,
268                        self.decimal_dtype(),
269                        patch_indices,
270                        offset,
271                        patch_values,
272                        patched_validity,
273                    )
274                })
275            })
276        })
277    }
278}
279
280fn patch_typed<I, ValuesDVT, PatchDVT>(
281    mut buffer: BufferMut<ValuesDVT>,
282    decimal_dtype: DecimalDType,
283    patch_indices: &[I],
284    patch_indices_offset: usize,
285    patch_values: Buffer<PatchDVT>,
286    patched_validity: Validity,
287) -> DecimalArray
288where
289    I: IntegerPType,
290    PatchDVT: NativeDecimalType,
291    ValuesDVT: NativeDecimalType,
292{
293    if !ValuesDVT::DECIMAL_TYPE.is_compatible_decimal_value_type(decimal_dtype) {
294        vortex_panic!(
295            "patch_typed: {:?} cannot represent every value in {}.",
296            ValuesDVT::DECIMAL_TYPE,
297            decimal_dtype
298        )
299    }
300
301    for (idx, value) in patch_indices.iter().zip_eq(patch_values.into_iter()) {
302        buffer[idx.as_() - patch_indices_offset] = <ValuesDVT as BigCast>::from(value).vortex_expect(
303            "values of a given DecimalDType are representable in all compatible NativeDecimalType",
304        );
305    }
306
307    DecimalArray::new(buffer.freeze(), decimal_dtype, patched_validity)
308}