Skip to main content

vortex_array/arrays/decimal/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use vortex_buffer::BitBufferMut;
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_buffer::ByteBuffer;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13
14use crate::ExecutionCtx;
15use crate::arrays::PrimitiveArray;
16use crate::buffer::BufferHandle;
17use crate::dtype::BigCast;
18use crate::dtype::DType;
19use crate::dtype::DecimalDType;
20use crate::dtype::DecimalType;
21use crate::dtype::IntegerPType;
22use crate::dtype::NativeDecimalType;
23use crate::match_each_decimal_value_type;
24use crate::match_each_integer_ptype;
25use crate::patches::Patches;
26use crate::stats::ArrayStats;
27use crate::validity::Validity;
28use crate::vtable::ValidityHelper;
29
/// A decimal array that stores fixed-precision decimal numbers with configurable scale.
///
/// This mirrors the Apache Arrow Decimal encoding and provides exact arithmetic for
/// financial and scientific computations where floating-point precision loss is unacceptable.
///
/// ## Storage Format
///
/// Decimals are stored as scaled integers in a supported scalar value type.
///
/// The precisions supported for each scalar type are:
/// - **i8**: precision 1-2 digits
/// - **i16**: precision 3-4 digits
/// - **i32**: precision 5-9 digits
/// - **i64**: precision 10-18 digits
/// - **i128**: precision 19-38 digits
/// - **i256**: precision 39-76 digits
///
/// These are just the maximal ranges for each scalar type, but it is perfectly legal to store
/// values with precision that does not match this exactly. For example, a valid DecimalArray with
/// precision=39 may store its values in an `i8` if all of the actual values fit into it.
///
/// Similarly, a `DecimalArray` can be built that stores a set of precision=2 values in a
/// `Buffer<i256>`.
///
/// ## Precision and Scale
///
/// - **Precision**: Total number of significant digits (1-76, u8 range)
/// - **Scale**: Number of digits after the decimal point (-128 to 127, i8 range)
/// - **Value**: `stored_integer / 10^scale`
///
/// For example, with precision=5 and scale=2:
/// - Stored value 12345 represents 123.45
/// - Range: -999.99 to 999.99
///
/// ## Valid Scalar Types
///
/// The underlying storage uses one of these native types:
/// - `DecimalType::I8`, `I16`, `I32`, `I64`, `I128`, `I256`
/// - The storage type is not derived from the precision: the typed constructors record the
///   element type of the buffer they are given (`T::DECIMAL_TYPE`), and the handle-based
///   constructors take the `DecimalType` as an explicit argument.
///
/// # Examples
///
/// ```
/// use vortex_array::arrays::DecimalArray;
/// use vortex_array::dtype::DecimalDType;
/// use vortex_buffer::{buffer, Buffer};
/// use vortex_array::validity::Validity;
///
/// // Create a decimal array with precision=5, scale=2 (e.g., 123.45)
/// let decimal_dtype = DecimalDType::new(5, 2);
/// let values = buffer![12345i32, 67890i32, -12300i32]; // 123.45, 678.90, -123.00
/// let array = DecimalArray::new(values, decimal_dtype, Validity::NonNullable);
///
/// assert_eq!(array.precision(), 5);
/// assert_eq!(array.scale(), 2);
/// assert_eq!(array.len(), 3);
/// ```
#[derive(Clone, Debug)]
pub struct DecimalArray {
    /// Always a `DType::Decimal`; its nullability is taken from `validity` at construction.
    pub(super) dtype: DType,
    /// Handle to the bytes of the scaled integer values.
    pub(super) values: BufferHandle,
    /// The integer type that the bytes in `values` are stored as.
    pub(super) values_type: DecimalType,
    /// Validity (null-ness) of the values.
    pub(super) validity: Validity,
    /// Statistics for this array; the constructors start it at its default value.
    pub(super) stats_set: ArrayStats,
}
95
/// The owned components of a [`DecimalArray`], as produced by [`DecimalArray::into_parts`].
pub struct DecimalArrayParts {
    /// Precision and scale of the decimal values.
    pub decimal_dtype: DecimalDType,
    /// Handle to the bytes of the scaled integer values.
    pub values: BufferHandle,
    /// The integer type that the bytes in `values` are stored as.
    pub values_type: DecimalType,
    /// Validity (null-ness) of the values.
    pub validity: Validity,
}
102
103impl DecimalArray {
104    /// Creates a new [`DecimalArray`] using a host-native buffer.
105    ///
106    /// # Panics
107    ///
108    /// Panics if the provided components do not satisfy the invariants documented in
109    /// [`DecimalArray::new_unchecked`].
110    pub fn new<T: NativeDecimalType>(
111        buffer: Buffer<T>,
112        decimal_dtype: DecimalDType,
113        validity: Validity,
114    ) -> Self {
115        Self::try_new(buffer, decimal_dtype, validity)
116            .vortex_expect("DecimalArray construction failed")
117    }
118
119    /// Creates a new [`DecimalArray`] from a [`BufferHandle`] of values that may live in
120    /// host or device memory.
121    ///
122    /// # Panics
123    ///
124    /// Panics if the provided components do not satisfy the invariants documented in
125    /// [`DecimalArray::new_unchecked`].
126    pub fn new_handle(
127        values: BufferHandle,
128        values_type: DecimalType,
129        decimal_dtype: DecimalDType,
130        validity: Validity,
131    ) -> Self {
132        Self::try_new_handle(values, values_type, decimal_dtype, validity)
133            .vortex_expect("DecimalArray construction failed")
134    }
135
136    /// Constructs a new `DecimalArray`.
137    ///
138    /// See [`DecimalArray::new_unchecked`] for more information.
139    ///
140    /// # Errors
141    ///
142    /// Returns an error if the provided components do not satisfy the invariants documented in
143    /// [`DecimalArray::new_unchecked`].
144    pub fn try_new<T: NativeDecimalType>(
145        buffer: Buffer<T>,
146        decimal_dtype: DecimalDType,
147        validity: Validity,
148    ) -> VortexResult<Self> {
149        let values = BufferHandle::new_host(buffer.into_byte_buffer());
150        let values_type = T::DECIMAL_TYPE;
151
152        Self::try_new_handle(values, values_type, decimal_dtype, validity)
153    }
154
155    /// Constructs a new `DecimalArray` with validation from a [`BufferHandle`].
156    ///
157    /// This pathway allows building new decimal arrays that may come from host or device memory.
158    ///
159    /// # Errors
160    ///
161    /// See [`DecimalArray::new_unchecked`] for invariants that are checked.
162    pub fn try_new_handle(
163        values: BufferHandle,
164        values_type: DecimalType,
165        decimal_dtype: DecimalDType,
166        validity: Validity,
167    ) -> VortexResult<Self> {
168        Self::validate(&values, values_type, &validity)?;
169
170        // SAFETY: validate ensures all invariants are met.
171        Ok(unsafe { Self::new_unchecked_handle(values, values_type, decimal_dtype, validity) })
172    }
173
174    /// Creates a new [`DecimalArray`] without validation from these components:
175    ///
176    /// * `buffer` is a typed buffer containing the decimal values.
177    /// * `decimal_dtype` specifies the decimal precision and scale.
178    /// * `validity` holds the null values.
179    ///
180    /// # Safety
181    ///
182    /// The caller must ensure all of the following invariants are satisfied:
183    ///
184    /// - All non-null values in `buffer` must be representable within the specified precision.
185    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
186    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
187    pub unsafe fn new_unchecked<T: NativeDecimalType>(
188        buffer: Buffer<T>,
189        decimal_dtype: DecimalDType,
190        validity: Validity,
191    ) -> Self {
192        // SAFETY: new_unchecked_handle inherits the safety guarantees of new_unchecked
193        unsafe {
194            Self::new_unchecked_handle(
195                BufferHandle::new_host(buffer.into_byte_buffer()),
196                T::DECIMAL_TYPE,
197                decimal_dtype,
198                validity,
199            )
200        }
201    }
202
203    /// Create a new array with decimal values backed by the given buffer handle.
204    ///
205    /// # Safety
206    ///
207    /// The caller must ensure all of the following invariants are satisfied:
208    ///
209    /// - All non-null values in `values` must be representable within the specified precision.
210    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
211    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
212    pub unsafe fn new_unchecked_handle(
213        values: BufferHandle,
214        values_type: DecimalType,
215        decimal_dtype: DecimalDType,
216        validity: Validity,
217    ) -> Self {
218        #[cfg(debug_assertions)]
219        {
220            Self::validate(&values, values_type, &validity)
221                .vortex_expect("[Debug Assertion]: Invalid `DecimalArray` parameters");
222        }
223
224        Self {
225            values,
226            values_type,
227            dtype: DType::Decimal(decimal_dtype, validity.nullability()),
228            validity,
229            stats_set: Default::default(),
230        }
231    }
232
233    /// Validates the components that would be used to create a [`DecimalArray`] from a byte buffer.
234    ///
235    /// This function checks all the invariants required by [`DecimalArray::new_unchecked`].
236    fn validate(
237        buffer: &BufferHandle,
238        values_type: DecimalType,
239        validity: &Validity,
240    ) -> VortexResult<()> {
241        if let Some(validity_len) = validity.maybe_len() {
242            let expected_len = values_type.byte_width() * validity_len;
243            vortex_ensure!(
244                buffer.len() == expected_len,
245                InvalidArgument: "expected buffer of size {} bytes, was {} bytes",
246                expected_len,
247                buffer.len(),
248            );
249        }
250
251        Ok(())
252    }
253
254    /// Creates a new [`DecimalArray`] from a raw byte buffer without validation.
255    ///
256    /// # Safety
257    ///
258    /// The caller must ensure:
259    /// - The `byte_buffer` contains valid data for the specified `values_type`
260    /// - The buffer length is compatible with the `values_type` (i.e., divisible by the type size)
261    /// - All non-null values are representable within the specified precision
262    /// - If `validity` is [`Validity::Array`], its length must equal the number of elements
263    pub unsafe fn new_unchecked_from_byte_buffer(
264        byte_buffer: ByteBuffer,
265        values_type: DecimalType,
266        decimal_dtype: DecimalDType,
267        validity: Validity,
268    ) -> Self {
269        // SAFETY: inherits the same safety contract as `new_unchecked_from_byte_buffer`
270        unsafe {
271            Self::new_unchecked_handle(
272                BufferHandle::new_host(byte_buffer),
273                values_type,
274                decimal_dtype,
275                validity,
276            )
277        }
278    }
279
280    pub fn into_parts(self) -> DecimalArrayParts {
281        let decimal_dtype = self.dtype.into_decimal_opt().vortex_expect("cannot fail");
282
283        DecimalArrayParts {
284            decimal_dtype,
285            values: self.values,
286            values_type: self.values_type,
287            validity: self.validity,
288        }
289    }
290
    /// Returns a reference to the [`BufferHandle`] that holds the array's values.
    pub fn buffer_handle(&self) -> &BufferHandle {
        &self.values
    }
295
296    pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
297        if self.values_type != T::DECIMAL_TYPE {
298            vortex_panic!(
299                "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
300                T::DECIMAL_TYPE,
301                self.values_type,
302            );
303        }
304        Buffer::<T>::from_byte_buffer(self.values.as_host().clone())
305    }
306
307    /// Returns the decimal type information
308    pub fn decimal_dtype(&self) -> DecimalDType {
309        if let DType::Decimal(decimal_dtype, _) = self.dtype {
310            decimal_dtype
311        } else {
312            vortex_panic!("Expected Decimal dtype, got {:?}", self.dtype)
313        }
314    }
315
    /// Returns the [`DecimalType`] used to store the values in the array.
    ///
    /// This is the storage type recorded when the array was constructed.
    pub fn values_type(&self) -> DecimalType {
        self.values_type
    }
320
321    pub fn precision(&self) -> u8 {
322        self.decimal_dtype().precision()
323    }
324
325    pub fn scale(&self) -> i8 {
326        self.decimal_dtype().scale()
327    }
328
329    pub fn from_iter<T: NativeDecimalType, I: IntoIterator<Item = T>>(
330        iter: I,
331        decimal_dtype: DecimalDType,
332    ) -> Self {
333        let iter = iter.into_iter();
334
335        Self::new(
336            BufferMut::from_iter(iter).freeze(),
337            decimal_dtype,
338            Validity::NonNullable,
339        )
340    }
341
342    pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
343        iter: I,
344        decimal_dtype: DecimalDType,
345    ) -> Self {
346        let iter = iter.into_iter();
347        let mut values = BufferMut::with_capacity(iter.size_hint().0);
348        let mut validity = BitBufferMut::with_capacity(values.capacity());
349
350        for i in iter {
351            match i {
352                None => {
353                    validity.append(false);
354                    values.push(T::default());
355                }
356                Some(e) => {
357                    validity.append(true);
358                    values.push(e);
359                }
360            }
361        }
362        Self::new(
363            values.freeze(),
364            decimal_dtype,
365            Validity::from(validity.freeze()),
366        )
367    }
368
369    #[expect(
370        clippy::cognitive_complexity,
371        reason = "complexity from nested match_each_* macros"
372    )]
373    pub fn patch(self, patches: &Patches, ctx: &mut ExecutionCtx) -> VortexResult<Self> {
374        let offset = patches.offset();
375        let patch_indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
376        let patch_values = patches.values().clone().execute::<DecimalArray>(ctx)?;
377
378        let patched_validity = self.validity().clone().patch(
379            self.len(),
380            offset,
381            &patch_indices.to_array(),
382            patch_values.validity(),
383            ctx,
384        )?;
385        assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
386
387        Ok(match_each_integer_ptype!(patch_indices.ptype(), |I| {
388            let patch_indices = patch_indices.as_slice::<I>();
389            match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
390                let patch_values = patch_values.buffer::<PatchDVT>();
391                match_each_decimal_value_type!(self.values_type(), |ValuesDVT| {
392                    let buffer = self.buffer::<ValuesDVT>().into_mut();
393                    patch_typed(
394                        buffer,
395                        self.decimal_dtype(),
396                        patch_indices,
397                        offset,
398                        patch_values,
399                        patched_validity,
400                    )
401                })
402            })
403        }))
404    }
405}
406
407fn patch_typed<I, ValuesDVT, PatchDVT>(
408    mut buffer: BufferMut<ValuesDVT>,
409    decimal_dtype: DecimalDType,
410    patch_indices: &[I],
411    patch_indices_offset: usize,
412    patch_values: Buffer<PatchDVT>,
413    patched_validity: Validity,
414) -> DecimalArray
415where
416    I: IntegerPType,
417    PatchDVT: NativeDecimalType,
418    ValuesDVT: NativeDecimalType,
419{
420    if !ValuesDVT::DECIMAL_TYPE.is_compatible_decimal_value_type(decimal_dtype) {
421        vortex_panic!(
422            "patch_typed: {:?} cannot represent every value in {}.",
423            ValuesDVT::DECIMAL_TYPE,
424            decimal_dtype
425        )
426    }
427
428    for (idx, value) in patch_indices.iter().zip_eq(patch_values.into_iter()) {
429        buffer[idx.as_() - patch_indices_offset] = <ValuesDVT as BigCast>::from(value).vortex_expect(
430            "values of a given DecimalDType are representable in all compatible NativeDecimalType",
431        );
432    }
433
434    DecimalArray::new(buffer.freeze(), decimal_dtype, patched_validity)
435}