Skip to main content

vortex_array/arrays/decimal/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use vortex_buffer::BitBufferMut;
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_buffer::ByteBuffer;
9use vortex_dtype::BigCast;
10use vortex_dtype::DType;
11use vortex_dtype::DecimalDType;
12use vortex_dtype::DecimalType;
13use vortex_dtype::IntegerPType;
14use vortex_dtype::NativeDecimalType;
15use vortex_dtype::match_each_decimal_value_type;
16use vortex_dtype::match_each_integer_ptype;
17use vortex_error::VortexExpect;
18use vortex_error::VortexResult;
19use vortex_error::vortex_ensure;
20use vortex_error::vortex_panic;
21
22use crate::ToCanonical;
23use crate::buffer::BufferHandle;
24use crate::patches::Patches;
25use crate::stats::ArrayStats;
26use crate::validity::Validity;
27use crate::vtable::ValidityHelper;
28
29/// A decimal array that stores fixed-precision decimal numbers with configurable scale.
30///
31/// This mirrors the Apache Arrow Decimal encoding and provides exact arithmetic for
32/// financial and scientific computations where floating-point precision loss is unacceptable.
33///
34/// ## Storage Format
35///
36/// Decimals are stored as scaled integers in a supported scalar value type.
37///
38/// The precisions supported for each scalar type are:
39/// - **i8**: precision 1-2 digits
40/// - **i16**: precision 3-4 digits
41/// - **i32**: precision 5-9 digits
42/// - **i64**: precision 10-18 digits
43/// - **i128**: precision 19-38 digits
44/// - **i256**: precision 39-76 digits
45///
46/// These are just the maximal ranges for each scalar type, but it is perfectly legal to store
47/// values with precision that does not match this exactly. For example, a valid DecimalArray with
48/// precision=39 may store its values in an `i8` if all of the actual values fit into it.
49///
50/// Similarly, a `DecimalArray` can be built that stores a set of precision=2 values in a
51/// `Buffer<i256>`.
52///
53/// ## Precision and Scale
54///
55/// - **Precision**: Total number of significant digits (1-76, u8 range)
56/// - **Scale**: Number of digits after the decimal point (-128 to 127, i8 range)
57/// - **Value**: `stored_integer / 10^scale`
58///
59/// For example, with precision=5 and scale=2:
60/// - Stored value 12345 represents 123.45
61/// - Range: -999.99 to 999.99
62///
63/// ## Valid Scalar Types
64///
65/// The underlying storage uses these native types based on precision:
66/// - `DecimalType::I8`, `I16`, `I32`, `I64`, `I128`, `I256`
67/// - Type selection is automatic based on the required precision
68///
69/// # Examples
70///
71/// ```
72/// use vortex_array::arrays::DecimalArray;
73/// use vortex_dtype::DecimalDType;
74/// use vortex_buffer::{buffer, Buffer};
75/// use vortex_array::validity::Validity;
76///
77/// // Create a decimal array with precision=5, scale=2 (e.g., 123.45)
78/// let decimal_dtype = DecimalDType::new(5, 2);
79/// let values = buffer![12345i32, 67890i32, -12300i32]; // 123.45, 678.90, -123.00
80/// let array = DecimalArray::new(values, decimal_dtype, Validity::NonNullable);
81///
82/// assert_eq!(array.precision(), 5);
83/// assert_eq!(array.scale(), 2);
84/// assert_eq!(array.len(), 3);
85/// ```
86#[derive(Clone, Debug)]
87pub struct DecimalArray {
88    pub(super) dtype: DType,
89    pub(super) values: BufferHandle,
90    pub(super) values_type: DecimalType,
91    pub(super) validity: Validity,
92    pub(super) stats_set: ArrayStats,
93}
94
95pub struct DecimalArrayParts {
96    pub decimal_dtype: DecimalDType,
97    pub values: BufferHandle,
98    pub values_type: DecimalType,
99    pub validity: Validity,
100}
101
102impl DecimalArray {
103    /// Creates a new [`DecimalArray`] using a host-native buffer.
104    ///
105    /// # Panics
106    ///
107    /// Panics if the provided components do not satisfy the invariants documented in
108    /// [`DecimalArray::new_unchecked`].
109    pub fn new<T: NativeDecimalType>(
110        buffer: Buffer<T>,
111        decimal_dtype: DecimalDType,
112        validity: Validity,
113    ) -> Self {
114        Self::try_new(buffer, decimal_dtype, validity)
115            .vortex_expect("DecimalArray construction failed")
116    }
117
118    /// Creates a new [`DecimalArray`] from a [`BufferHandle`] of values that may live in
119    /// host or device memory.
120    ///
121    /// # Panics
122    ///
123    /// Panics if the provided components do not satisfy the invariants documented in
124    /// [`DecimalArray::new_unchecked`].
125    pub fn new_handle(
126        values: BufferHandle,
127        values_type: DecimalType,
128        decimal_dtype: DecimalDType,
129        validity: Validity,
130    ) -> Self {
131        Self::try_new_handle(values, values_type, decimal_dtype, validity)
132            .vortex_expect("DecimalArray construction failed")
133    }
134
135    /// Constructs a new `DecimalArray`.
136    ///
137    /// See [`DecimalArray::new_unchecked`] for more information.
138    ///
139    /// # Errors
140    ///
141    /// Returns an error if the provided components do not satisfy the invariants documented in
142    /// [`DecimalArray::new_unchecked`].
143    pub fn try_new<T: NativeDecimalType>(
144        buffer: Buffer<T>,
145        decimal_dtype: DecimalDType,
146        validity: Validity,
147    ) -> VortexResult<Self> {
148        let values = BufferHandle::new_host(buffer.into_byte_buffer());
149        let values_type = T::DECIMAL_TYPE;
150
151        Self::try_new_handle(values, values_type, decimal_dtype, validity)
152    }
153
154    /// Constructs a new `DecimalArray` with validation from a [`BufferHandle`].
155    ///
156    /// This pathway allows building new decimal arrays that may come from host or device memory.
157    ///
158    /// # Errors
159    ///
160    /// See [`DecimalArray::new_unchecked`] for invariants that are checked.
161    pub fn try_new_handle(
162        values: BufferHandle,
163        values_type: DecimalType,
164        decimal_dtype: DecimalDType,
165        validity: Validity,
166    ) -> VortexResult<Self> {
167        Self::validate(&values, values_type, &validity)?;
168
169        // SAFETY: validate ensures all invariants are met.
170        Ok(unsafe { Self::new_unchecked_handle(values, values_type, decimal_dtype, validity) })
171    }
172
173    /// Creates a new [`DecimalArray`] without validation from these components:
174    ///
175    /// * `buffer` is a typed buffer containing the decimal values.
176    /// * `decimal_dtype` specifies the decimal precision and scale.
177    /// * `validity` holds the null values.
178    ///
179    /// # Safety
180    ///
181    /// The caller must ensure all of the following invariants are satisfied:
182    ///
183    /// - All non-null values in `buffer` must be representable within the specified precision.
184    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
185    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
186    pub unsafe fn new_unchecked<T: NativeDecimalType>(
187        buffer: Buffer<T>,
188        decimal_dtype: DecimalDType,
189        validity: Validity,
190    ) -> Self {
191        // SAFETY: new_unchecked_handle inherits the safety guarantees of new_unchecked
192        unsafe {
193            Self::new_unchecked_handle(
194                BufferHandle::new_host(buffer.into_byte_buffer()),
195                T::DECIMAL_TYPE,
196                decimal_dtype,
197                validity,
198            )
199        }
200    }
201
202    /// Create a new array with decimal values backed by the given buffer handle.
203    ///
204    /// # Safety
205    ///
206    /// The caller must ensure all of the following invariants are satisfied:
207    ///
208    /// - All non-null values in `values` must be representable within the specified precision.
209    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
210    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
211    pub unsafe fn new_unchecked_handle(
212        values: BufferHandle,
213        values_type: DecimalType,
214        decimal_dtype: DecimalDType,
215        validity: Validity,
216    ) -> Self {
217        #[cfg(debug_assertions)]
218        {
219            Self::validate(&values, values_type, &validity)
220                .vortex_expect("[Debug Assertion]: Invalid `DecimalArray` parameters");
221        }
222
223        Self {
224            values,
225            values_type,
226            dtype: DType::Decimal(decimal_dtype, validity.nullability()),
227            validity,
228            stats_set: Default::default(),
229        }
230    }
231
232    /// Validates the components that would be used to create a [`DecimalArray`] from a byte buffer.
233    ///
234    /// This function checks all the invariants required by [`DecimalArray::new_unchecked`].
235    fn validate(
236        buffer: &BufferHandle,
237        values_type: DecimalType,
238        validity: &Validity,
239    ) -> VortexResult<()> {
240        if let Some(validity_len) = validity.maybe_len() {
241            let expected_len = values_type.byte_width() * validity_len;
242            vortex_ensure!(
243                buffer.len() == expected_len,
244                InvalidArgument: "expected buffer of size {} bytes, was {} bytes",
245                expected_len,
246                buffer.len(),
247            );
248        }
249
250        Ok(())
251    }
252
253    /// Creates a new [`DecimalArray`] from a raw byte buffer without validation.
254    ///
255    /// # Safety
256    ///
257    /// The caller must ensure:
258    /// - The `byte_buffer` contains valid data for the specified `values_type`
259    /// - The buffer length is compatible with the `values_type` (i.e., divisible by the type size)
260    /// - All non-null values are representable within the specified precision
261    /// - If `validity` is [`Validity::Array`], its length must equal the number of elements
262    pub unsafe fn new_unchecked_from_byte_buffer(
263        byte_buffer: ByteBuffer,
264        values_type: DecimalType,
265        decimal_dtype: DecimalDType,
266        validity: Validity,
267    ) -> Self {
268        // SAFETY: inherits the same safety contract as `new_unchecked_from_byte_buffer`
269        unsafe {
270            Self::new_unchecked_handle(
271                BufferHandle::new_host(byte_buffer),
272                values_type,
273                decimal_dtype,
274                validity,
275            )
276        }
277    }
278
279    pub fn into_parts(self) -> DecimalArrayParts {
280        let decimal_dtype = self.dtype.into_decimal_opt().vortex_expect("cannot fail");
281
282        DecimalArrayParts {
283            decimal_dtype,
284            values: self.values,
285            values_type: self.values_type,
286            validity: self.validity,
287        }
288    }
289
290    /// Returns the underlying [`ByteBuffer`] of the array.
291    pub fn buffer_handle(&self) -> &BufferHandle {
292        &self.values
293    }
294
295    pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
296        if self.values_type != T::DECIMAL_TYPE {
297            vortex_panic!(
298                "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
299                T::DECIMAL_TYPE,
300                self.values_type,
301            );
302        }
303        Buffer::<T>::from_byte_buffer(self.values.as_host().clone())
304    }
305
306    /// Returns the decimal type information
307    pub fn decimal_dtype(&self) -> DecimalDType {
308        if let DType::Decimal(decimal_dtype, _) = self.dtype {
309            decimal_dtype
310        } else {
311            vortex_panic!("Expected Decimal dtype, got {:?}", self.dtype)
312        }
313    }
314
315    /// Return the `DecimalType` used to represent the values in the array.
316    pub fn values_type(&self) -> DecimalType {
317        self.values_type
318    }
319
320    pub fn precision(&self) -> u8 {
321        self.decimal_dtype().precision()
322    }
323
324    pub fn scale(&self) -> i8 {
325        self.decimal_dtype().scale()
326    }
327
328    pub fn from_iter<T: NativeDecimalType, I: IntoIterator<Item = T>>(
329        iter: I,
330        decimal_dtype: DecimalDType,
331    ) -> Self {
332        let iter = iter.into_iter();
333
334        Self::new(
335            BufferMut::from_iter(iter).freeze(),
336            decimal_dtype,
337            Validity::NonNullable,
338        )
339    }
340
341    pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
342        iter: I,
343        decimal_dtype: DecimalDType,
344    ) -> Self {
345        let iter = iter.into_iter();
346        let mut values = BufferMut::with_capacity(iter.size_hint().0);
347        let mut validity = BitBufferMut::with_capacity(values.capacity());
348
349        for i in iter {
350            match i {
351                None => {
352                    validity.append(false);
353                    values.push(T::default());
354                }
355                Some(e) => {
356                    validity.append(true);
357                    values.push(e);
358                }
359            }
360        }
361        Self::new(
362            values.freeze(),
363            decimal_dtype,
364            Validity::from(validity.freeze()),
365        )
366    }
367
368    #[expect(
369        clippy::cognitive_complexity,
370        reason = "complexity from nested match_each_* macros"
371    )]
372    pub fn patch(self, patches: &Patches) -> VortexResult<Self> {
373        let offset = patches.offset();
374        let patch_indices = patches.indices().to_primitive();
375        let patch_values = patches.values().to_decimal();
376
377        let patched_validity = self.validity().clone().patch(
378            self.len(),
379            offset,
380            patch_indices.as_ref(),
381            patch_values.validity(),
382        )?;
383        assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
384
385        Ok(match_each_integer_ptype!(patch_indices.ptype(), |I| {
386            let patch_indices = patch_indices.as_slice::<I>();
387            match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
388                let patch_values = patch_values.buffer::<PatchDVT>();
389                match_each_decimal_value_type!(self.values_type(), |ValuesDVT| {
390                    let buffer = self.buffer::<ValuesDVT>().into_mut();
391                    patch_typed(
392                        buffer,
393                        self.decimal_dtype(),
394                        patch_indices,
395                        offset,
396                        patch_values,
397                        patched_validity,
398                    )
399                })
400            })
401        }))
402    }
403}
404
405fn patch_typed<I, ValuesDVT, PatchDVT>(
406    mut buffer: BufferMut<ValuesDVT>,
407    decimal_dtype: DecimalDType,
408    patch_indices: &[I],
409    patch_indices_offset: usize,
410    patch_values: Buffer<PatchDVT>,
411    patched_validity: Validity,
412) -> DecimalArray
413where
414    I: IntegerPType,
415    PatchDVT: NativeDecimalType,
416    ValuesDVT: NativeDecimalType,
417{
418    if !ValuesDVT::DECIMAL_TYPE.is_compatible_decimal_value_type(decimal_dtype) {
419        vortex_panic!(
420            "patch_typed: {:?} cannot represent every value in {}.",
421            ValuesDVT::DECIMAL_TYPE,
422            decimal_dtype
423        )
424    }
425
426    for (idx, value) in patch_indices.iter().zip_eq(patch_values.into_iter()) {
427        buffer[idx.as_() - patch_indices_offset] = <ValuesDVT as BigCast>::from(value).vortex_expect(
428            "values of a given DecimalDType are representable in all compatible NativeDecimalType",
429        );
430    }
431
432    DecimalArray::new(buffer.freeze(), decimal_dtype, patched_validity)
433}