Skip to main content

vortex_array/arrays/decimal/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use vortex_buffer::BitBufferMut;
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_buffer::ByteBuffer;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13
14use crate::ExecutionCtx;
15use crate::IntoArray;
16use crate::arrays::PrimitiveArray;
17use crate::buffer::BufferHandle;
18use crate::dtype::BigCast;
19use crate::dtype::DType;
20use crate::dtype::DecimalDType;
21use crate::dtype::DecimalType;
22use crate::dtype::IntegerPType;
23use crate::dtype::NativeDecimalType;
24use crate::match_each_decimal_value_type;
25use crate::match_each_integer_ptype;
26use crate::patches::Patches;
27use crate::stats::ArrayStats;
28use crate::validity::Validity;
29use crate::vtable::ValidityHelper;
30
31/// A decimal array that stores fixed-precision decimal numbers with configurable scale.
32///
33/// This mirrors the Apache Arrow Decimal encoding and provides exact arithmetic for
34/// financial and scientific computations where floating-point precision loss is unacceptable.
35///
36/// ## Storage Format
37///
38/// Decimals are stored as scaled integers in a supported scalar value type.
39///
40/// The precisions supported for each scalar type are:
41/// - **i8**: precision 1-2 digits
42/// - **i16**: precision 3-4 digits
43/// - **i32**: precision 5-9 digits
44/// - **i64**: precision 10-18 digits
45/// - **i128**: precision 19-38 digits
46/// - **i256**: precision 39-76 digits
47///
48/// These are just the maximal ranges for each scalar type, but it is perfectly legal to store
49/// values with precision that does not match this exactly. For example, a valid DecimalArray with
50/// precision=39 may store its values in an `i8` if all of the actual values fit into it.
51///
52/// Similarly, a `DecimalArray` can be built that stores a set of precision=2 values in a
53/// `Buffer<i256>`.
54///
55/// ## Precision and Scale
56///
57/// - **Precision**: Total number of significant digits (1-76, u8 range)
58/// - **Scale**: Number of digits after the decimal point (-128 to 127, i8 range)
59/// - **Value**: `stored_integer / 10^scale`
60///
61/// For example, with precision=5 and scale=2:
62/// - Stored value 12345 represents 123.45
63/// - Range: -999.99 to 999.99
64///
65/// ## Valid Scalar Types
66///
67/// The underlying storage uses these native types based on precision:
68/// - `DecimalType::I8`, `I16`, `I32`, `I64`, `I128`, `I256`
69/// - Type selection is automatic based on the required precision
70///
71/// # Examples
72///
73/// ```
74/// use vortex_array::arrays::DecimalArray;
75/// use vortex_array::dtype::DecimalDType;
76/// use vortex_buffer::{buffer, Buffer};
77/// use vortex_array::validity::Validity;
78///
79/// // Create a decimal array with precision=5, scale=2 (e.g., 123.45)
80/// let decimal_dtype = DecimalDType::new(5, 2);
81/// let values = buffer![12345i32, 67890i32, -12300i32]; // 123.45, 678.90, -123.00
82/// let array = DecimalArray::new(values, decimal_dtype, Validity::NonNullable);
83///
84/// assert_eq!(array.precision(), 5);
85/// assert_eq!(array.scale(), 2);
86/// assert_eq!(array.len(), 3);
87/// ```
88#[derive(Clone, Debug)]
89pub struct DecimalArray {
90    pub(super) dtype: DType,
91    pub(super) values: BufferHandle,
92    pub(super) values_type: DecimalType,
93    pub(super) validity: Validity,
94    pub(super) stats_set: ArrayStats,
95}
96
97pub struct DecimalArrayParts {
98    pub decimal_dtype: DecimalDType,
99    pub values: BufferHandle,
100    pub values_type: DecimalType,
101    pub validity: Validity,
102}
103
104impl DecimalArray {
105    /// Creates a new [`DecimalArray`] using a host-native buffer.
106    ///
107    /// # Panics
108    ///
109    /// Panics if the provided components do not satisfy the invariants documented in
110    /// [`DecimalArray::new_unchecked`].
111    pub fn new<T: NativeDecimalType>(
112        buffer: Buffer<T>,
113        decimal_dtype: DecimalDType,
114        validity: Validity,
115    ) -> Self {
116        Self::try_new(buffer, decimal_dtype, validity)
117            .vortex_expect("DecimalArray construction failed")
118    }
119
120    /// Creates a new [`DecimalArray`] from a [`BufferHandle`] of values that may live in
121    /// host or device memory.
122    ///
123    /// # Panics
124    ///
125    /// Panics if the provided components do not satisfy the invariants documented in
126    /// [`DecimalArray::new_unchecked`].
127    pub fn new_handle(
128        values: BufferHandle,
129        values_type: DecimalType,
130        decimal_dtype: DecimalDType,
131        validity: Validity,
132    ) -> Self {
133        Self::try_new_handle(values, values_type, decimal_dtype, validity)
134            .vortex_expect("DecimalArray construction failed")
135    }
136
137    /// Constructs a new `DecimalArray`.
138    ///
139    /// See [`DecimalArray::new_unchecked`] for more information.
140    ///
141    /// # Errors
142    ///
143    /// Returns an error if the provided components do not satisfy the invariants documented in
144    /// [`DecimalArray::new_unchecked`].
145    pub fn try_new<T: NativeDecimalType>(
146        buffer: Buffer<T>,
147        decimal_dtype: DecimalDType,
148        validity: Validity,
149    ) -> VortexResult<Self> {
150        let values = BufferHandle::new_host(buffer.into_byte_buffer());
151        let values_type = T::DECIMAL_TYPE;
152
153        Self::try_new_handle(values, values_type, decimal_dtype, validity)
154    }
155
156    /// Constructs a new `DecimalArray` with validation from a [`BufferHandle`].
157    ///
158    /// This pathway allows building new decimal arrays that may come from host or device memory.
159    ///
160    /// # Errors
161    ///
162    /// See [`DecimalArray::new_unchecked`] for invariants that are checked.
163    pub fn try_new_handle(
164        values: BufferHandle,
165        values_type: DecimalType,
166        decimal_dtype: DecimalDType,
167        validity: Validity,
168    ) -> VortexResult<Self> {
169        Self::validate(&values, values_type, &validity)?;
170
171        // SAFETY: validate ensures all invariants are met.
172        Ok(unsafe { Self::new_unchecked_handle(values, values_type, decimal_dtype, validity) })
173    }
174
175    /// Creates a new [`DecimalArray`] without validation from these components:
176    ///
177    /// * `buffer` is a typed buffer containing the decimal values.
178    /// * `decimal_dtype` specifies the decimal precision and scale.
179    /// * `validity` holds the null values.
180    ///
181    /// # Safety
182    ///
183    /// The caller must ensure all of the following invariants are satisfied:
184    ///
185    /// - All non-null values in `buffer` must be representable within the specified precision.
186    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
187    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
188    pub unsafe fn new_unchecked<T: NativeDecimalType>(
189        buffer: Buffer<T>,
190        decimal_dtype: DecimalDType,
191        validity: Validity,
192    ) -> Self {
193        // SAFETY: new_unchecked_handle inherits the safety guarantees of new_unchecked
194        unsafe {
195            Self::new_unchecked_handle(
196                BufferHandle::new_host(buffer.into_byte_buffer()),
197                T::DECIMAL_TYPE,
198                decimal_dtype,
199                validity,
200            )
201        }
202    }
203
204    /// Create a new array with decimal values backed by the given buffer handle.
205    ///
206    /// # Safety
207    ///
208    /// The caller must ensure all of the following invariants are satisfied:
209    ///
210    /// - All non-null values in `values` must be representable within the specified precision.
211    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
212    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
213    pub unsafe fn new_unchecked_handle(
214        values: BufferHandle,
215        values_type: DecimalType,
216        decimal_dtype: DecimalDType,
217        validity: Validity,
218    ) -> Self {
219        #[cfg(debug_assertions)]
220        {
221            Self::validate(&values, values_type, &validity)
222                .vortex_expect("[Debug Assertion]: Invalid `DecimalArray` parameters");
223        }
224
225        Self {
226            values,
227            values_type,
228            dtype: DType::Decimal(decimal_dtype, validity.nullability()),
229            validity,
230            stats_set: Default::default(),
231        }
232    }
233
234    /// Validates the components that would be used to create a [`DecimalArray`] from a byte buffer.
235    ///
236    /// This function checks all the invariants required by [`DecimalArray::new_unchecked`].
237    fn validate(
238        buffer: &BufferHandle,
239        values_type: DecimalType,
240        validity: &Validity,
241    ) -> VortexResult<()> {
242        if let Some(validity_len) = validity.maybe_len() {
243            let expected_len = values_type.byte_width() * validity_len;
244            vortex_ensure!(
245                buffer.len() == expected_len,
246                InvalidArgument: "expected buffer of size {} bytes, was {} bytes",
247                expected_len,
248                buffer.len(),
249            );
250        }
251
252        Ok(())
253    }
254
255    /// Creates a new [`DecimalArray`] from a raw byte buffer without validation.
256    ///
257    /// # Safety
258    ///
259    /// The caller must ensure:
260    /// - The `byte_buffer` contains valid data for the specified `values_type`
261    /// - The buffer length is compatible with the `values_type` (i.e., divisible by the type size)
262    /// - All non-null values are representable within the specified precision
263    /// - If `validity` is [`Validity::Array`], its length must equal the number of elements
264    pub unsafe fn new_unchecked_from_byte_buffer(
265        byte_buffer: ByteBuffer,
266        values_type: DecimalType,
267        decimal_dtype: DecimalDType,
268        validity: Validity,
269    ) -> Self {
270        // SAFETY: inherits the same safety contract as `new_unchecked_from_byte_buffer`
271        unsafe {
272            Self::new_unchecked_handle(
273                BufferHandle::new_host(byte_buffer),
274                values_type,
275                decimal_dtype,
276                validity,
277            )
278        }
279    }
280
281    pub fn into_parts(self) -> DecimalArrayParts {
282        let decimal_dtype = self.dtype.into_decimal_opt().vortex_expect("cannot fail");
283
284        DecimalArrayParts {
285            decimal_dtype,
286            values: self.values,
287            values_type: self.values_type,
288            validity: self.validity,
289        }
290    }
291
292    /// Returns the underlying [`ByteBuffer`] of the array.
293    pub fn buffer_handle(&self) -> &BufferHandle {
294        &self.values
295    }
296
297    pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
298        if self.values_type != T::DECIMAL_TYPE {
299            vortex_panic!(
300                "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
301                T::DECIMAL_TYPE,
302                self.values_type,
303            );
304        }
305        Buffer::<T>::from_byte_buffer(self.values.as_host().clone())
306    }
307
308    /// Returns the decimal type information
309    pub fn decimal_dtype(&self) -> DecimalDType {
310        if let DType::Decimal(decimal_dtype, _) = self.dtype {
311            decimal_dtype
312        } else {
313            vortex_panic!("Expected Decimal dtype, got {:?}", self.dtype)
314        }
315    }
316
317    /// Return the `DecimalType` used to represent the values in the array.
318    pub fn values_type(&self) -> DecimalType {
319        self.values_type
320    }
321
322    pub fn precision(&self) -> u8 {
323        self.decimal_dtype().precision()
324    }
325
326    pub fn scale(&self) -> i8 {
327        self.decimal_dtype().scale()
328    }
329
330    pub fn from_iter<T: NativeDecimalType, I: IntoIterator<Item = T>>(
331        iter: I,
332        decimal_dtype: DecimalDType,
333    ) -> Self {
334        let iter = iter.into_iter();
335
336        Self::new(
337            BufferMut::from_iter(iter).freeze(),
338            decimal_dtype,
339            Validity::NonNullable,
340        )
341    }
342
343    pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
344        iter: I,
345        decimal_dtype: DecimalDType,
346    ) -> Self {
347        let iter = iter.into_iter();
348        let mut values = BufferMut::with_capacity(iter.size_hint().0);
349        let mut validity = BitBufferMut::with_capacity(values.capacity());
350
351        for i in iter {
352            match i {
353                None => {
354                    validity.append(false);
355                    values.push(T::default());
356                }
357                Some(e) => {
358                    validity.append(true);
359                    values.push(e);
360                }
361            }
362        }
363        Self::new(
364            values.freeze(),
365            decimal_dtype,
366            Validity::from(validity.freeze()),
367        )
368    }
369
370    #[expect(
371        clippy::cognitive_complexity,
372        reason = "complexity from nested match_each_* macros"
373    )]
374    pub fn patch(self, patches: &Patches, ctx: &mut ExecutionCtx) -> VortexResult<Self> {
375        let offset = patches.offset();
376        let patch_indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
377        let patch_values = patches.values().clone().execute::<DecimalArray>(ctx)?;
378
379        let patched_validity = self.validity().clone().patch(
380            self.len(),
381            offset,
382            &patch_indices.clone().into_array(),
383            patch_values.validity(),
384            ctx,
385        )?;
386        assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
387
388        Ok(match_each_integer_ptype!(patch_indices.ptype(), |I| {
389            let patch_indices = patch_indices.as_slice::<I>();
390            match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
391                let patch_values = patch_values.buffer::<PatchDVT>();
392                match_each_decimal_value_type!(self.values_type(), |ValuesDVT| {
393                    let buffer = self.buffer::<ValuesDVT>().into_mut();
394                    patch_typed(
395                        buffer,
396                        self.decimal_dtype(),
397                        patch_indices,
398                        offset,
399                        patch_values,
400                        patched_validity,
401                    )
402                })
403            })
404        }))
405    }
406}
407
408fn patch_typed<I, ValuesDVT, PatchDVT>(
409    mut buffer: BufferMut<ValuesDVT>,
410    decimal_dtype: DecimalDType,
411    patch_indices: &[I],
412    patch_indices_offset: usize,
413    patch_values: Buffer<PatchDVT>,
414    patched_validity: Validity,
415) -> DecimalArray
416where
417    I: IntegerPType,
418    PatchDVT: NativeDecimalType,
419    ValuesDVT: NativeDecimalType,
420{
421    if !ValuesDVT::DECIMAL_TYPE.is_compatible_decimal_value_type(decimal_dtype) {
422        vortex_panic!(
423            "patch_typed: {:?} cannot represent every value in {}.",
424            ValuesDVT::DECIMAL_TYPE,
425            decimal_dtype
426        )
427    }
428
429    for (idx, value) in patch_indices.iter().zip_eq(patch_values.into_iter()) {
430        buffer[idx.as_() - patch_indices_offset] = <ValuesDVT as BigCast>::from(value).vortex_expect(
431            "values of a given DecimalDType are representable in all compatible NativeDecimalType",
432        );
433    }
434
435    DecimalArray::new(buffer.freeze(), decimal_dtype, patched_validity)
436}