Skip to main content

vortex_array/arrays/decimal/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use vortex_buffer::Alignment;
9use vortex_buffer::BitBufferMut;
10use vortex_buffer::Buffer;
11use vortex_buffer::BufferMut;
12use vortex_buffer::ByteBuffer;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_ensure;
16use vortex_error::vortex_panic;
17
18use crate::ArrayRef;
19use crate::ExecutionCtx;
20use crate::IntoArray;
21use crate::array::Array;
22use crate::array::ArrayParts;
23use crate::array::TypedArrayRef;
24use crate::array::child_to_validity;
25use crate::array::validity_to_child;
26use crate::arrays::Decimal;
27use crate::arrays::DecimalArray;
28use crate::arrays::PrimitiveArray;
29use crate::buffer::BufferHandle;
30use crate::dtype::BigCast;
31use crate::dtype::DType;
32use crate::dtype::DecimalDType;
33use crate::dtype::DecimalType;
34use crate::dtype::IntegerPType;
35use crate::dtype::NativeDecimalType;
36use crate::dtype::Nullability;
37use crate::match_each_decimal_value_type;
38use crate::match_each_integer_ptype;
39use crate::patches::Patches;
40use crate::validity::Validity;
41
42/// The validity bitmap indicating which elements are non-null.
43pub(super) const VALIDITY_SLOT: usize = 0;
44pub(super) const NUM_SLOTS: usize = 1;
45pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
46
/// A decimal array that stores fixed-precision decimal numbers with configurable scale.
///
/// This mirrors the Apache Arrow Decimal encoding and provides exact arithmetic for
/// financial and scientific computations where floating-point precision loss is unacceptable.
///
/// ## Storage Format
///
/// Decimals are stored as scaled integers in a supported scalar value type.
///
/// The maximum precision each scalar type can represent in full is:
/// - **i8**: precision 1-2 digits
/// - **i16**: precision 3-4 digits
/// - **i32**: precision 5-9 digits
/// - **i64**: precision 10-18 digits
/// - **i128**: precision 19-38 digits
/// - **i256**: precision 39-76 digits
///
/// These are just the maximal ranges for each scalar type, but it is perfectly legal to store
/// values with a precision that does not match this exactly. For example, a valid `DecimalArray`
/// with precision=39 may store its values in an `i8` if all of the actual values fit into it.
///
/// Similarly, a `DecimalArray` can be built that stores a set of precision=2 values in a
/// `Buffer<i256>`.
///
/// ## Precision and Scale
///
/// - **Precision**: Total number of significant digits (1-76, u8 range)
/// - **Scale**: Number of digits after the decimal point (-128 to 127, i8 range)
/// - **Value**: `stored_integer / 10^scale`
///
/// For example, with precision=5 and scale=2:
/// - Stored value 12345 represents 123.45
/// - Range: -999.99 to 999.99
///
/// ## Valid Scalar Types
///
/// The underlying storage uses one of these native types:
/// - `DecimalType::I8`, `I16`, `I32`, `I64`, `I128`, `I256`
/// - The storage type is *not* inferred from the precision. It is taken from the element type
///   `T` of the `Buffer<T>` passed to the typed constructors (`T::DECIMAL_TYPE`), or from the
///   explicit `values_type` argument of the [`BufferHandle`]-based constructors.
///
/// # Examples
///
/// ```
/// use vortex_array::arrays::DecimalArray;
/// use vortex_array::dtype::DecimalDType;
/// use vortex_buffer::{buffer, Buffer};
/// use vortex_array::validity::Validity;
///
/// // Create a decimal array with precision=5, scale=2 (e.g., 123.45)
/// let decimal_dtype = DecimalDType::new(5, 2);
/// let values = buffer![12345i32, 67890i32, -12300i32]; // 123.45, 678.90, -123.00
/// let array = DecimalArray::new(values, decimal_dtype, Validity::NonNullable);
///
/// assert_eq!(array.precision(), 5);
/// assert_eq!(array.scale(), 2);
/// assert_eq!(array.len(), 3);
/// ```
#[derive(Clone, Debug)]
pub struct DecimalData {
    /// Precision and scale of the logical decimal values.
    pub(super) decimal_dtype: DecimalDType,
    /// Raw bytes of the scaled integer values; may reside in host or device memory.
    pub(super) values: BufferHandle,
    /// Physical integer type used to interpret each element of `values`.
    pub(super) values_type: DecimalType,
}
110
111impl Display for DecimalData {
112    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
113        write!(
114            f,
115            "decimal_dtype: {}, values_type: {}",
116            self.decimal_dtype, self.values_type
117        )
118    }
119}
120
/// The owned components of a decimal array, as returned by `Array::<Decimal>::into_data_parts`.
pub struct DecimalDataParts {
    /// Precision and scale of the logical decimal values.
    pub decimal_dtype: DecimalDType,
    /// Buffer holding the scaled integer values.
    pub values: BufferHandle,
    /// Physical integer type of each element in `values`.
    pub values_type: DecimalType,
    /// Validity (null information) reconstructed from the array's validity slot.
    pub validity: Validity,
}
127
128pub trait DecimalArrayExt: TypedArrayRef<Decimal> {
129    fn decimal_dtype(&self) -> DecimalDType {
130        match self.as_ref().dtype() {
131            DType::Decimal(decimal_dtype, _) => *decimal_dtype,
132            _ => unreachable!("DecimalArrayExt requires a decimal dtype"),
133        }
134    }
135
136    fn nullability(&self) -> Nullability {
137        match self.as_ref().dtype() {
138            DType::Decimal(_, nullability) => *nullability,
139            _ => unreachable!("DecimalArrayExt requires a decimal dtype"),
140        }
141    }
142
143    fn validity_child(&self) -> Option<&ArrayRef> {
144        self.as_ref().slots()[VALIDITY_SLOT].as_ref()
145    }
146
147    fn validity(&self) -> Validity {
148        child_to_validity(
149            self.as_ref().slots()[VALIDITY_SLOT].as_ref(),
150            self.nullability(),
151        )
152    }
153
154    fn values_type(&self) -> DecimalType {
155        self.values_type
156    }
157
158    fn precision(&self) -> u8 {
159        self.decimal_dtype().precision()
160    }
161
162    fn scale(&self) -> i8 {
163        self.decimal_dtype().scale()
164    }
165
166    fn buffer_handle(&self) -> &BufferHandle {
167        &self.values
168    }
169
170    fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
171        DecimalData::buffer::<T>(self)
172    }
173}
174impl<T: TypedArrayRef<Decimal>> DecimalArrayExt for T {}
175
176impl DecimalData {
177    /// Build the slots vector for this array.
178    pub(super) fn make_slots(validity: &Validity, len: usize) -> Vec<Option<ArrayRef>> {
179        vec![validity_to_child(validity, len)]
180    }
181
182    /// Creates a new [`DecimalArray`] using a host-native buffer.
183    ///
184    /// # Panics
185    ///
186    /// Panics if the provided components do not satisfy the invariants documented in
187    /// [`DecimalArray::new_unchecked`].
188    pub fn new<T: NativeDecimalType>(buffer: Buffer<T>, decimal_dtype: DecimalDType) -> Self {
189        Self::try_new(buffer, decimal_dtype).vortex_expect("DecimalArray construction failed")
190    }
191
192    /// Creates a new [`DecimalArray`] from a [`BufferHandle`] of values that may live in
193    /// host or device memory.
194    ///
195    /// # Panics
196    ///
197    /// Panics if the provided components do not satisfy the invariants documented in
198    /// [`DecimalArray::new_unchecked`].
199    pub fn new_handle(
200        values: BufferHandle,
201        values_type: DecimalType,
202        decimal_dtype: DecimalDType,
203    ) -> Self {
204        Self::try_new_handle(values, values_type, decimal_dtype)
205            .vortex_expect("DecimalArray construction failed")
206    }
207
208    /// Constructs a new `DecimalArray`.
209    ///
210    /// See [`DecimalArray::new_unchecked`] for more information.
211    ///
212    /// # Errors
213    ///
214    /// Returns an error if the provided components do not satisfy the invariants documented in
215    /// [`DecimalArray::new_unchecked`].
216    pub fn try_new<T: NativeDecimalType>(
217        buffer: Buffer<T>,
218        decimal_dtype: DecimalDType,
219    ) -> VortexResult<Self> {
220        let values = BufferHandle::new_host(buffer.into_byte_buffer());
221        let values_type = T::DECIMAL_TYPE;
222
223        Self::try_new_handle(values, values_type, decimal_dtype)
224    }
225
226    /// Constructs a new `DecimalArray` with validation from a [`BufferHandle`].
227    ///
228    /// This pathway allows building new decimal arrays that may come from host or device memory.
229    ///
230    /// # Errors
231    ///
232    /// See [`DecimalArray::new_unchecked`] for invariants that are checked.
233    pub fn try_new_handle(
234        values: BufferHandle,
235        values_type: DecimalType,
236        decimal_dtype: DecimalDType,
237    ) -> VortexResult<Self> {
238        Self::validate(&values, values_type)?;
239
240        // SAFETY: validate ensures all invariants are met.
241        Ok(unsafe { Self::new_unchecked_handle(values, values_type, decimal_dtype) })
242    }
243
244    /// Creates a new [`DecimalArray`] without validation from these components:
245    ///
246    /// * `buffer` is a typed buffer containing the decimal values.
247    /// * `decimal_dtype` specifies the decimal precision and scale.
248    /// * `validity` holds the null values.
249    ///
250    /// # Safety
251    ///
252    /// The caller must ensure all of the following invariants are satisfied:
253    ///
254    /// - All non-null values in `buffer` must be representable within the specified precision.
255    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
256    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
257    pub unsafe fn new_unchecked<T: NativeDecimalType>(
258        buffer: Buffer<T>,
259        decimal_dtype: DecimalDType,
260    ) -> Self {
261        // SAFETY: new_unchecked_handle inherits the safety guarantees of new_unchecked
262        unsafe {
263            Self::new_unchecked_handle(
264                BufferHandle::new_host(buffer.into_byte_buffer()),
265                T::DECIMAL_TYPE,
266                decimal_dtype,
267            )
268        }
269    }
270
271    /// Create a new array with decimal values backed by the given buffer handle.
272    ///
273    /// # Safety
274    ///
275    /// The caller must ensure all of the following invariants are satisfied:
276    ///
277    /// - All non-null values in `values` must be representable within the specified precision.
278    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
279    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
280    pub unsafe fn new_unchecked_handle(
281        values: BufferHandle,
282        values_type: DecimalType,
283        decimal_dtype: DecimalDType,
284    ) -> Self {
285        Self {
286            decimal_dtype,
287            values,
288            values_type,
289        }
290    }
291
292    /// Validates the components that would be used to create a [`DecimalArray`] from a byte buffer.
293    ///
294    /// This function checks all the invariants required by [`DecimalArray::new_unchecked`].
295    fn validate(buffer: &BufferHandle, values_type: DecimalType) -> VortexResult<()> {
296        let byte_width = values_type.byte_width();
297        vortex_ensure!(
298            buffer.len().is_multiple_of(byte_width),
299            InvalidArgument: "decimal buffer size {} is not divisible by element width {}",
300            buffer.len(),
301            byte_width,
302        );
303        match_each_decimal_value_type!(values_type, |D| {
304            vortex_ensure!(
305                buffer.is_aligned_to(Alignment::of::<D>()),
306                InvalidArgument: "decimal buffer alignment {:?} is invalid for values type {:?}",
307                buffer.alignment(),
308                D::DECIMAL_TYPE,
309            );
310            Ok::<(), vortex_error::VortexError>(())
311        })?;
312        Ok(())
313    }
314
315    /// Creates a new [`DecimalArray`] from a raw byte buffer without validation.
316    ///
317    /// # Safety
318    ///
319    /// The caller must ensure:
320    /// - The `byte_buffer` contains valid data for the specified `values_type`
321    /// - The buffer length is compatible with the `values_type` (i.e., divisible by the type size)
322    /// - All non-null values are representable within the specified precision
323    /// - If `validity` is [`Validity::Array`], its length must equal the number of elements
324    pub unsafe fn new_unchecked_from_byte_buffer(
325        byte_buffer: ByteBuffer,
326        values_type: DecimalType,
327        decimal_dtype: DecimalDType,
328    ) -> Self {
329        // SAFETY: inherits the same safety contract as `new_unchecked_from_byte_buffer`
330        unsafe {
331            Self::new_unchecked_handle(
332                BufferHandle::new_host(byte_buffer),
333                values_type,
334                decimal_dtype,
335            )
336        }
337    }
338
339    /// Returns the length of this array.
340    pub fn len(&self) -> usize {
341        self.values.len() / self.values_type.byte_width()
342    }
343
344    /// Returns `true` if this array is empty.
345    pub fn is_empty(&self) -> bool {
346        self.len() == 0
347    }
348
349    /// Returns the underlying [`ByteBuffer`] of the array.
350    pub fn buffer_handle(&self) -> &BufferHandle {
351        &self.values
352    }
353
354    pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
355        if self.values_type != T::DECIMAL_TYPE {
356            vortex_panic!(
357                "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
358                T::DECIMAL_TYPE,
359                self.values_type,
360            );
361        }
362        Buffer::<T>::from_byte_buffer(self.values.as_host().clone())
363    }
364
365    /// Return the `DecimalType` used to represent the values in the array.
366    pub fn values_type(&self) -> DecimalType {
367        self.values_type
368    }
369
370    /// Returns the decimal type information.
371    pub fn decimal_dtype(&self) -> DecimalDType {
372        self.decimal_dtype
373    }
374
375    pub fn precision(&self) -> u8 {
376        self.decimal_dtype.precision()
377    }
378
379    pub fn scale(&self) -> i8 {
380        self.decimal_dtype.scale()
381    }
382}
383
384impl Array<Decimal> {
385    pub fn into_data_parts(self) -> DecimalDataParts {
386        let validity = DecimalArrayExt::validity(&self);
387        let decimal_dtype = DecimalArrayExt::decimal_dtype(&self);
388        let data = self.into_data();
389        DecimalDataParts {
390            decimal_dtype,
391            values: data.values,
392            values_type: data.values_type,
393            validity,
394        }
395    }
396}
397
398impl Array<Decimal> {
399    /// Creates a new [`DecimalArray`] using a host-native buffer.
400    pub fn new<T: NativeDecimalType>(
401        buffer: Buffer<T>,
402        decimal_dtype: DecimalDType,
403        validity: Validity,
404    ) -> Self {
405        Self::try_new(buffer, decimal_dtype, validity)
406            .vortex_expect("DecimalArray construction failed")
407    }
408
409    /// Creates a new [`DecimalArray`] without validation.
410    ///
411    /// # Safety
412    ///
413    /// See [`DecimalData::new_unchecked`].
414    pub unsafe fn new_unchecked<T: NativeDecimalType>(
415        buffer: Buffer<T>,
416        decimal_dtype: DecimalDType,
417        validity: Validity,
418    ) -> Self {
419        let dtype = DType::Decimal(decimal_dtype, validity.nullability());
420        let len = buffer.len();
421        let slots = DecimalData::make_slots(&validity, len);
422        let data = unsafe { DecimalData::new_unchecked(buffer, decimal_dtype) };
423        unsafe {
424            Array::from_parts_unchecked(
425                ArrayParts::new(Decimal, dtype, len, data).with_slots(slots),
426            )
427        }
428    }
429
430    /// Creates a new [`DecimalArray`] from a host-native buffer with validation.
431    pub fn try_new<T: NativeDecimalType>(
432        buffer: Buffer<T>,
433        decimal_dtype: DecimalDType,
434        validity: Validity,
435    ) -> VortexResult<Self> {
436        let dtype = DType::Decimal(decimal_dtype, validity.nullability());
437        let len = buffer.len();
438        let slots = DecimalData::make_slots(&validity, len);
439        let data = DecimalData::try_new(buffer, decimal_dtype)?;
440        Array::try_from_parts(ArrayParts::new(Decimal, dtype, len, data).with_slots(slots))
441    }
442
443    /// Creates a new [`DecimalArray`] from an iterator of values.
444    #[expect(
445        clippy::same_name_method,
446        reason = "intentionally named from_iter like Iterator::from_iter"
447    )]
448    pub fn from_iter<T: NativeDecimalType, I: IntoIterator<Item = T>>(
449        iter: I,
450        decimal_dtype: DecimalDType,
451    ) -> Self {
452        Self::new(
453            BufferMut::from_iter(iter).freeze(),
454            decimal_dtype,
455            Validity::NonNullable,
456        )
457    }
458
459    /// Creates a new [`DecimalArray`] from an iterator of optional values.
460    pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
461        iter: I,
462        decimal_dtype: DecimalDType,
463    ) -> Self {
464        let iter = iter.into_iter();
465        let mut values = BufferMut::with_capacity(iter.size_hint().0);
466        let mut validity = BitBufferMut::with_capacity(values.capacity());
467
468        for value in iter {
469            match value {
470                Some(value) => {
471                    values.push(value);
472                    validity.append(true);
473                }
474                None => {
475                    values.push(T::default());
476                    validity.append(false);
477                }
478            }
479        }
480
481        Self::new(
482            values.freeze(),
483            decimal_dtype,
484            Validity::from(validity.freeze()),
485        )
486    }
487
488    /// Creates a new [`DecimalArray`] from a [`BufferHandle`].
489    pub fn new_handle(
490        values: BufferHandle,
491        values_type: DecimalType,
492        decimal_dtype: DecimalDType,
493        validity: Validity,
494    ) -> Self {
495        Self::try_new_handle(values, values_type, decimal_dtype, validity)
496            .vortex_expect("DecimalArray construction failed")
497    }
498
499    /// Creates a new [`DecimalArray`] from a [`BufferHandle`] with validation.
500    pub fn try_new_handle(
501        values: BufferHandle,
502        values_type: DecimalType,
503        decimal_dtype: DecimalDType,
504        validity: Validity,
505    ) -> VortexResult<Self> {
506        let dtype = DType::Decimal(decimal_dtype, validity.nullability());
507        let len = values.len() / values_type.byte_width();
508        let slots = DecimalData::make_slots(&validity, len);
509        let data = DecimalData::try_new_handle(values, values_type, decimal_dtype)?;
510        Array::try_from_parts(ArrayParts::new(Decimal, dtype, len, data).with_slots(slots))
511    }
512
513    /// Creates a new [`DecimalArray`] without validation from a [`BufferHandle`].
514    ///
515    /// # Safety
516    ///
517    /// See [`DecimalData::new_unchecked_handle`].
518    pub unsafe fn new_unchecked_handle(
519        values: BufferHandle,
520        values_type: DecimalType,
521        decimal_dtype: DecimalDType,
522        validity: Validity,
523    ) -> Self {
524        let dtype = DType::Decimal(decimal_dtype, validity.nullability());
525        let len = values.len() / values_type.byte_width();
526        let slots = DecimalData::make_slots(&validity, len);
527        let data = unsafe { DecimalData::new_unchecked_handle(values, values_type, decimal_dtype) };
528        unsafe {
529            Array::from_parts_unchecked(
530                ArrayParts::new(Decimal, dtype, len, data).with_slots(slots),
531            )
532        }
533    }
534
535    #[expect(
536        clippy::cognitive_complexity,
537        reason = "patching depends on both patch and value physical types"
538    )]
539    pub fn patch(self, patches: &Patches, ctx: &mut ExecutionCtx) -> VortexResult<Self> {
540        let offset = patches.offset();
541        let dtype = self.dtype().clone();
542        let len = self.len();
543        let patch_indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
544        let patch_values = patches.values().clone().execute::<DecimalArray>(ctx)?;
545
546        let patch_validity = patch_values.validity()?;
547        let patched_validity = self.validity()?.patch(
548            self.len(),
549            offset,
550            &patch_indices.clone().into_array(),
551            &patch_validity,
552            ctx,
553        )?;
554        assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
555
556        let data = self.into_data();
557        let data = match_each_integer_ptype!(patch_indices.ptype(), |I| {
558            let patch_indices = patch_indices.as_slice::<I>();
559            match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
560                let patch_values = patch_values.buffer::<PatchDVT>();
561                match_each_decimal_value_type!(data.values_type(), |ValuesDVT| {
562                    let buffer = data.buffer::<ValuesDVT>().into_mut();
563                    patch_typed(
564                        buffer,
565                        data.decimal_dtype(),
566                        patch_indices,
567                        offset,
568                        patch_values,
569                    )
570                })
571            })
572        });
573        let slots = DecimalData::make_slots(&patched_validity, len);
574        Ok(unsafe {
575            Array::from_parts_unchecked(
576                ArrayParts::new(Decimal, dtype, len, data).with_slots(slots),
577            )
578        })
579    }
580}
581
582fn patch_typed<I, ValuesDVT, PatchDVT>(
583    mut buffer: BufferMut<ValuesDVT>,
584    decimal_dtype: DecimalDType,
585    patch_indices: &[I],
586    patch_indices_offset: usize,
587    patch_values: Buffer<PatchDVT>,
588) -> DecimalData
589where
590    I: IntegerPType,
591    PatchDVT: NativeDecimalType,
592    ValuesDVT: NativeDecimalType,
593{
594    if !ValuesDVT::DECIMAL_TYPE.is_compatible_decimal_value_type(decimal_dtype) {
595        vortex_panic!(
596            "patch_typed: {:?} cannot represent every value in {}.",
597            ValuesDVT::DECIMAL_TYPE,
598            decimal_dtype
599        )
600    }
601
602    for (idx, value) in patch_indices.iter().zip_eq(patch_values.into_iter()) {
603        buffer[idx.as_() - patch_indices_offset] = <ValuesDVT as BigCast>::from(value).vortex_expect(
604            "values of a given DecimalDType are representable in all compatible NativeDecimalType",
605        );
606    }
607
608    DecimalData::new(buffer.freeze(), decimal_dtype)
609}