vortex_array/arrays/decimal/array.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use vortex_buffer::BitBufferMut;
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_buffer::ByteBuffer;
9use vortex_dtype::BigCast;
10use vortex_dtype::DType;
11use vortex_dtype::DecimalDType;
12use vortex_dtype::DecimalType;
13use vortex_dtype::IntegerPType;
14use vortex_dtype::NativeDecimalType;
15use vortex_dtype::match_each_decimal_value_type;
16use vortex_dtype::match_each_integer_ptype;
17use vortex_error::VortexExpect;
18use vortex_error::VortexResult;
19use vortex_error::vortex_ensure;
20use vortex_error::vortex_panic;
21
22use crate::ToCanonical;
23use crate::patches::Patches;
24use crate::stats::ArrayStats;
25use crate::validity::Validity;
26use crate::vtable::ValidityHelper;
27
28/// A decimal array that stores fixed-precision decimal numbers with configurable scale.
29///
30/// This mirrors the Apache Arrow Decimal encoding and provides exact arithmetic for
31/// financial and scientific computations where floating-point precision loss is unacceptable.
32///
33/// ## Storage Format
34///
35/// Decimals are stored as scaled integers in a supported scalar value type.
36///
37/// The precisions supported for each scalar type are:
38/// - **i8**: precision 1-2 digits
39/// - **i16**: precision 3-4 digits
40/// - **i32**: precision 5-9 digits
41/// - **i64**: precision 10-18 digits
42/// - **i128**: precision 19-38 digits
43/// - **i256**: precision 39-76 digits
44///
45/// These are just the maximal ranges for each scalar type, but it is perfectly legal to store
46/// values with precision that does not match this exactly. For example, a valid DecimalArray with
47/// precision=39 may store its values in an `i8` if all of the actual values fit into it.
48///
49/// Similarly, a `DecimalArray` can be built that stores a set of precision=2 values in a
50/// `Buffer<i256>`.
51///
52/// ## Precision and Scale
53///
54/// - **Precision**: Total number of significant digits (1-76, u8 range)
55/// - **Scale**: Number of digits after the decimal point (-128 to 127, i8 range)
56/// - **Value**: `stored_integer / 10^scale`
57///
58/// For example, with precision=5 and scale=2:
59/// - Stored value 12345 represents 123.45
60/// - Range: -999.99 to 999.99
61///
62/// ## Valid Scalar Types
63///
64/// The underlying storage uses these native types based on precision:
65/// - `DecimalType::I8`, `I16`, `I32`, `I64`, `I128`, `I256`
66/// - Type selection is automatic based on the required precision
67///
68/// # Examples
69///
70/// ```
71/// use vortex_array::arrays::DecimalArray;
72/// use vortex_dtype::DecimalDType;
73/// use vortex_buffer::{buffer, Buffer};
74/// use vortex_array::validity::Validity;
75///
76/// // Create a decimal array with precision=5, scale=2 (e.g., 123.45)
77/// let decimal_dtype = DecimalDType::new(5, 2);
78/// let values = buffer![12345i32, 67890i32, -12300i32]; // 123.45, 678.90, -123.00
79/// let array = DecimalArray::new(values, decimal_dtype, Validity::NonNullable);
80///
81/// assert_eq!(array.precision(), 5);
82/// assert_eq!(array.scale(), 2);
83/// assert_eq!(array.len(), 3);
84/// ```
85#[derive(Clone, Debug)]
86pub struct DecimalArray {
87    pub(super) dtype: DType,
88    pub(super) values: ByteBuffer,
89    pub(super) values_type: DecimalType,
90    pub(super) validity: Validity,
91    pub(super) stats_set: ArrayStats,
92}
93
94impl DecimalArray {
95    /// Creates a new [`DecimalArray`].
96    ///
97    /// # Panics
98    ///
99    /// Panics if the provided components do not satisfy the invariants documented in
100    /// [`DecimalArray::new_unchecked`].
101    pub fn new<T: NativeDecimalType>(
102        buffer: Buffer<T>,
103        decimal_dtype: DecimalDType,
104        validity: Validity,
105    ) -> Self {
106        Self::try_new(buffer, decimal_dtype, validity)
107            .vortex_expect("DecimalArray construction failed")
108    }
109
110    /// Constructs a new `DecimalArray`.
111    ///
112    /// See [`DecimalArray::new_unchecked`] for more information.
113    ///
114    /// # Errors
115    ///
116    /// Returns an error if the provided components do not satisfy the invariants documented in
117    /// [`DecimalArray::new_unchecked`].
118    pub fn try_new<T: NativeDecimalType>(
119        buffer: Buffer<T>,
120        decimal_dtype: DecimalDType,
121        validity: Validity,
122    ) -> VortexResult<Self> {
123        Self::validate(&buffer, &validity)?;
124
125        // SAFETY: validate ensures all invariants are met.
126        Ok(unsafe { Self::new_unchecked(buffer, decimal_dtype, validity) })
127    }
128
129    /// Creates a new [`DecimalArray`] without validation from these components:
130    ///
131    /// * `buffer` is a typed buffer containing the decimal values.
132    /// * `decimal_dtype` specifies the decimal precision and scale.
133    /// * `validity` holds the null values.
134    ///
135    /// # Safety
136    ///
137    /// The caller must ensure all of the following invariants are satisfied:
138    ///
139    /// - All non-null values in `buffer` must be representable within the specified precision.
140    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
141    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
142    pub unsafe fn new_unchecked<T: NativeDecimalType>(
143        buffer: Buffer<T>,
144        decimal_dtype: DecimalDType,
145        validity: Validity,
146    ) -> Self {
147        #[cfg(debug_assertions)]
148        Self::validate(&buffer, &validity)
149            .vortex_expect("[Debug Assertion]: Invalid `DecimalArray` parameters");
150
151        Self {
152            values: buffer.into_byte_buffer(),
153            values_type: T::DECIMAL_TYPE,
154            dtype: DType::Decimal(decimal_dtype, validity.nullability()),
155            validity,
156            stats_set: Default::default(),
157        }
158    }
159
160    /// Validates the components that would be used to create a [`DecimalArray`].
161    ///
162    /// This function checks all the invariants required by [`DecimalArray::new_unchecked`].
163    pub fn validate<T: NativeDecimalType>(
164        buffer: &Buffer<T>,
165        validity: &Validity,
166    ) -> VortexResult<()> {
167        if let Some(len) = validity.maybe_len() {
168            vortex_ensure!(
169                buffer.len() == len,
170                "Buffer and validity length mismatch: buffer={}, validity={}",
171                buffer.len(),
172                len,
173            );
174        }
175
176        Ok(())
177    }
178
179    /// Returns the underlying [`ByteBuffer`] of the array.
180    pub fn byte_buffer(&self) -> ByteBuffer {
181        self.values.clone()
182    }
183
184    pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
185        if self.values_type != T::DECIMAL_TYPE {
186            vortex_panic!(
187                "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
188                T::DECIMAL_TYPE,
189                self.values_type,
190            );
191        }
192        Buffer::<T>::from_byte_buffer(self.values.clone())
193    }
194
195    /// Returns the decimal type information
196    pub fn decimal_dtype(&self) -> DecimalDType {
197        if let DType::Decimal(decimal_dtype, _) = self.dtype {
198            decimal_dtype
199        } else {
200            vortex_panic!("Expected Decimal dtype, got {:?}", self.dtype)
201        }
202    }
203
204    pub fn values_type(&self) -> DecimalType {
205        self.values_type
206    }
207
208    pub fn precision(&self) -> u8 {
209        self.decimal_dtype().precision()
210    }
211
212    pub fn scale(&self) -> i8 {
213        self.decimal_dtype().scale()
214    }
215
216    pub fn from_iter<T: NativeDecimalType, I: IntoIterator<Item = T>>(
217        iter: I,
218        decimal_dtype: DecimalDType,
219    ) -> Self {
220        let iter = iter.into_iter();
221
222        Self::new(
223            BufferMut::from_iter(iter).freeze(),
224            decimal_dtype,
225            Validity::NonNullable,
226        )
227    }
228
229    pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
230        iter: I,
231        decimal_dtype: DecimalDType,
232    ) -> Self {
233        let iter = iter.into_iter();
234        let mut values = BufferMut::with_capacity(iter.size_hint().0);
235        let mut validity = BitBufferMut::with_capacity(values.capacity());
236
237        for i in iter {
238            match i {
239                None => {
240                    validity.append(false);
241                    values.push(T::default());
242                }
243                Some(e) => {
244                    validity.append(true);
245                    values.push(e);
246                }
247            }
248        }
249        Self::new(
250            values.freeze(),
251            decimal_dtype,
252            Validity::from(validity.freeze()),
253        )
254    }
255
256    #[expect(
257        clippy::cognitive_complexity,
258        reason = "complexity from nested match_each_* macros"
259    )]
260    pub fn patch(self, patches: &Patches) -> Self {
261        let offset = patches.offset();
262        let patch_indices = patches.indices().to_primitive();
263        let patch_values = patches.values().to_decimal();
264
265        let patched_validity = self.validity().clone().patch(
266            self.len(),
267            offset,
268            patch_indices.as_ref(),
269            patch_values.validity(),
270        );
271        assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
272
273        match_each_integer_ptype!(patch_indices.ptype(), |I| {
274            let patch_indices = patch_indices.as_slice::<I>();
275            match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
276                let patch_values = patch_values.buffer::<PatchDVT>();
277                match_each_decimal_value_type!(self.values_type(), |ValuesDVT| {
278                    let buffer = self.buffer::<ValuesDVT>().into_mut();
279                    patch_typed(
280                        buffer,
281                        self.decimal_dtype(),
282                        patch_indices,
283                        offset,
284                        patch_values,
285                        patched_validity,
286                    )
287                })
288            })
289        })
290    }
291}
292
293fn patch_typed<I, ValuesDVT, PatchDVT>(
294    mut buffer: BufferMut<ValuesDVT>,
295    decimal_dtype: DecimalDType,
296    patch_indices: &[I],
297    patch_indices_offset: usize,
298    patch_values: Buffer<PatchDVT>,
299    patched_validity: Validity,
300) -> DecimalArray
301where
302    I: IntegerPType,
303    PatchDVT: NativeDecimalType,
304    ValuesDVT: NativeDecimalType,
305{
306    if !ValuesDVT::DECIMAL_TYPE.is_compatible_decimal_value_type(decimal_dtype) {
307        vortex_panic!(
308            "patch_typed: {:?} cannot represent every value in {}.",
309            ValuesDVT::DECIMAL_TYPE,
310            decimal_dtype
311        )
312    }
313
314    for (idx, value) in patch_indices.iter().zip_eq(patch_values.into_iter()) {
315        buffer[idx.as_() - patch_indices_offset] = <ValuesDVT as BigCast>::from(value).vortex_expect(
316            "values of a given DecimalDType are representable in all compatible NativeDecimalType",
317        );
318    }
319
320    DecimalArray::new(buffer.freeze(), decimal_dtype, patched_validity)
321}