Skip to main content

vortex_array/arrays/decimal/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use smallvec::smallvec;
9use vortex_buffer::Alignment;
10use vortex_buffer::BitBufferMut;
11use vortex_buffer::Buffer;
12use vortex_buffer::BufferMut;
13use vortex_buffer::ByteBuffer;
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_ensure;
17use vortex_error::vortex_panic;
18
19use crate::ArrayRef;
20use crate::ArraySlots;
21use crate::ExecutionCtx;
22use crate::IntoArray;
23use crate::array::Array;
24use crate::array::ArrayParts;
25use crate::array::TypedArrayRef;
26use crate::array::child_to_validity;
27use crate::array::validity_to_child;
28use crate::arrays::Decimal;
29use crate::arrays::DecimalArray;
30use crate::arrays::PrimitiveArray;
31use crate::arrays::primitive::PrimitiveArrayExt;
32use crate::buffer::BufferHandle;
33use crate::dtype::BigCast;
34use crate::dtype::DType;
35use crate::dtype::DecimalDType;
36use crate::dtype::DecimalType;
37use crate::dtype::IntegerPType;
38use crate::dtype::NativeDecimalType;
39use crate::dtype::Nullability;
40use crate::match_each_decimal_value_type;
41use crate::match_each_unsigned_integer_ptype;
42use crate::patches::Patches;
43use crate::validity::Validity;
44
45/// The validity bitmap indicating which elements are non-null.
46pub(super) const VALIDITY_SLOT: usize = 0;
47pub(super) const NUM_SLOTS: usize = 1;
48pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
49
/// A decimal array that stores fixed-precision decimal numbers with configurable scale.
///
/// This mirrors the Apache Arrow Decimal encoding and provides exact arithmetic for
/// financial and scientific computations where floating-point precision loss is unacceptable.
///
/// ## Storage Format
///
/// Decimals are stored as scaled integers in a supported scalar value type.
///
/// The precisions supported for each scalar type are:
/// - **i8**: precision 1-2 digits
/// - **i16**: precision 3-4 digits
/// - **i32**: precision 5-9 digits
/// - **i64**: precision 10-18 digits
/// - **i128**: precision 19-38 digits
/// - **i256**: precision 39-76 digits
///
/// These are just the maximal ranges for each scalar type, but it is perfectly legal to store
/// values with a precision that does not match this exactly. For example, a valid `DecimalArray`
/// with precision=39 may store its values in an `i8` if all of the actual values fit into it.
///
/// Similarly, a `DecimalArray` can be built that stores a set of precision=2 values in a
/// `Buffer<i256>`.
///
/// ## Precision and Scale
///
/// - **Precision**: Total number of significant digits (1-76, u8 range)
/// - **Scale**: Number of digits after the decimal point (-128 to 127, i8 range)
/// - **Value**: `stored_integer / 10^scale`
///
/// For example, with precision=5 and scale=2:
/// - Stored value 12345 represents 123.45
/// - Range: -999.99 to 999.99
///
/// ## Valid Scalar Types
///
/// The underlying storage uses one of these native types:
/// - `DecimalType::I8`, `I16`, `I32`, `I64`, `I128`, `I256`
/// - The values type is taken from the element type of the buffer supplied at construction
///   (or passed explicitly when constructing from a [`BufferHandle`]); it is not derived from
///   the precision.
///
/// # Examples
///
/// ```
/// use vortex_array::arrays::DecimalArray;
/// use vortex_array::dtype::DecimalDType;
/// use vortex_buffer::{buffer, Buffer};
/// use vortex_array::validity::Validity;
///
/// // Create a decimal array with precision=5, scale=2 (e.g., 123.45)
/// let decimal_dtype = DecimalDType::new(5, 2);
/// let values = buffer![12345i32, 67890i32, -12300i32]; // 123.45, 678.90, -123.00
/// let array = DecimalArray::new(values, decimal_dtype, Validity::NonNullable);
///
/// assert_eq!(array.precision(), 5);
/// assert_eq!(array.scale(), 2);
/// assert_eq!(array.len(), 3);
/// ```
#[derive(Clone, Debug)]
pub struct DecimalData {
    /// Logical decimal type (precision and scale) of the stored values.
    pub(super) decimal_dtype: DecimalDType,
    /// Raw bytes of the scaled-integer values; may live in host or device memory.
    pub(super) values: BufferHandle,
    /// Physical integer type used to interpret the bytes in `values`.
    pub(super) values_type: DecimalType,
}
113
114impl Display for DecimalData {
115    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
116        write!(
117            f,
118            "decimal_dtype: {}, values_type: {}",
119            self.decimal_dtype, self.values_type
120        )
121    }
122}
123
/// The owned components of a decimal array, as produced by `Array<Decimal>::into_data_parts`.
pub struct DecimalDataParts {
    /// Logical decimal type (precision and scale) of the array.
    pub decimal_dtype: DecimalDType,
    /// Raw bytes of the scaled-integer values.
    pub values: BufferHandle,
    /// Physical integer type used to interpret the bytes in `values`.
    pub values_type: DecimalType,
    /// Validity of the array's elements.
    pub validity: Validity,
}
130
131pub trait DecimalArrayExt: TypedArrayRef<Decimal> {
132    fn decimal_dtype(&self) -> DecimalDType {
133        match self.as_ref().dtype() {
134            DType::Decimal(decimal_dtype, _) => *decimal_dtype,
135            _ => unreachable!("DecimalArrayExt requires a decimal dtype"),
136        }
137    }
138
139    fn nullability(&self) -> Nullability {
140        match self.as_ref().dtype() {
141            DType::Decimal(_, nullability) => *nullability,
142            _ => unreachable!("DecimalArrayExt requires a decimal dtype"),
143        }
144    }
145
146    fn validity_child(&self) -> Option<&ArrayRef> {
147        self.as_ref().slots()[VALIDITY_SLOT].as_ref()
148    }
149
150    fn validity(&self) -> Validity {
151        child_to_validity(
152            self.as_ref().slots()[VALIDITY_SLOT].as_ref(),
153            self.nullability(),
154        )
155    }
156
157    fn values_type(&self) -> DecimalType {
158        self.values_type
159    }
160
161    fn precision(&self) -> u8 {
162        self.decimal_dtype().precision()
163    }
164
165    fn scale(&self) -> i8 {
166        self.decimal_dtype().scale()
167    }
168
169    fn buffer_handle(&self) -> &BufferHandle {
170        &self.values
171    }
172
173    fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
174        DecimalData::buffer::<T>(self)
175    }
176}
177impl<T: TypedArrayRef<Decimal>> DecimalArrayExt for T {}
178
179impl DecimalData {
180    /// Build the slots vector for this array.
181    pub(super) fn make_slots(validity: &Validity, len: usize) -> ArraySlots {
182        smallvec![validity_to_child(validity, len)]
183    }
184
185    /// Creates a new [`DecimalArray`] using a host-native buffer.
186    ///
187    /// # Panics
188    ///
189    /// Panics if the provided components do not satisfy the invariants documented in
190    /// [`DecimalArray::new_unchecked`].
191    pub fn new<T: NativeDecimalType>(buffer: Buffer<T>, decimal_dtype: DecimalDType) -> Self {
192        Self::try_new(buffer, decimal_dtype).vortex_expect("DecimalArray construction failed")
193    }
194
195    /// Creates a new [`DecimalArray`] from a [`BufferHandle`] of values that may live in
196    /// host or device memory.
197    ///
198    /// # Panics
199    ///
200    /// Panics if the provided components do not satisfy the invariants documented in
201    /// [`DecimalArray::new_unchecked`].
202    pub fn new_handle(
203        values: BufferHandle,
204        values_type: DecimalType,
205        decimal_dtype: DecimalDType,
206    ) -> Self {
207        Self::try_new_handle(values, values_type, decimal_dtype)
208            .vortex_expect("DecimalArray construction failed")
209    }
210
211    /// Constructs a new `DecimalArray`.
212    ///
213    /// See [`DecimalArray::new_unchecked`] for more information.
214    ///
215    /// # Errors
216    ///
217    /// Returns an error if the provided components do not satisfy the invariants documented in
218    /// [`DecimalArray::new_unchecked`].
219    pub fn try_new<T: NativeDecimalType>(
220        buffer: Buffer<T>,
221        decimal_dtype: DecimalDType,
222    ) -> VortexResult<Self> {
223        let values = BufferHandle::new_host(buffer.into_byte_buffer());
224        let values_type = T::DECIMAL_TYPE;
225
226        Self::try_new_handle(values, values_type, decimal_dtype)
227    }
228
229    /// Constructs a new `DecimalArray` with validation from a [`BufferHandle`].
230    ///
231    /// This pathway allows building new decimal arrays that may come from host or device memory.
232    ///
233    /// # Errors
234    ///
235    /// See [`DecimalArray::new_unchecked`] for invariants that are checked.
236    pub fn try_new_handle(
237        values: BufferHandle,
238        values_type: DecimalType,
239        decimal_dtype: DecimalDType,
240    ) -> VortexResult<Self> {
241        Self::validate(&values, values_type)?;
242
243        // SAFETY: validate ensures all invariants are met.
244        Ok(unsafe { Self::new_unchecked_handle(values, values_type, decimal_dtype) })
245    }
246
247    /// Creates a new [`DecimalArray`] without validation from these components:
248    ///
249    /// * `buffer` is a typed buffer containing the decimal values.
250    /// * `decimal_dtype` specifies the decimal precision and scale.
251    /// * `validity` holds the null values.
252    ///
253    /// # Safety
254    ///
255    /// The caller must ensure all of the following invariants are satisfied:
256    ///
257    /// - All non-null values in `buffer` must be representable within the specified precision.
258    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
259    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
260    pub unsafe fn new_unchecked<T: NativeDecimalType>(
261        buffer: Buffer<T>,
262        decimal_dtype: DecimalDType,
263    ) -> Self {
264        // SAFETY: new_unchecked_handle inherits the safety guarantees of new_unchecked
265        unsafe {
266            Self::new_unchecked_handle(
267                BufferHandle::new_host(buffer.into_byte_buffer()),
268                T::DECIMAL_TYPE,
269                decimal_dtype,
270            )
271        }
272    }
273
274    /// Create a new array with decimal values backed by the given buffer handle.
275    ///
276    /// # Safety
277    ///
278    /// The caller must ensure all of the following invariants are satisfied:
279    ///
280    /// - All non-null values in `values` must be representable within the specified precision.
281    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
282    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
283    pub unsafe fn new_unchecked_handle(
284        values: BufferHandle,
285        values_type: DecimalType,
286        decimal_dtype: DecimalDType,
287    ) -> Self {
288        Self {
289            decimal_dtype,
290            values,
291            values_type,
292        }
293    }
294
295    /// Validates the components that would be used to create a [`DecimalArray`] from a byte buffer.
296    ///
297    /// This function checks all the invariants required by [`DecimalArray::new_unchecked`].
298    fn validate(buffer: &BufferHandle, values_type: DecimalType) -> VortexResult<()> {
299        let byte_width = values_type.byte_width();
300        vortex_ensure!(
301            buffer.len().is_multiple_of(byte_width),
302            InvalidArgument: "decimal buffer size {} is not divisible by element width {}",
303            buffer.len(),
304            byte_width,
305        );
306        match_each_decimal_value_type!(values_type, |D| {
307            vortex_ensure!(
308                buffer.is_aligned_to(Alignment::of::<D>()),
309                InvalidArgument: "decimal buffer alignment {:?} is invalid for values type {:?}",
310                buffer.alignment(),
311                D::DECIMAL_TYPE,
312            );
313            Ok::<(), vortex_error::VortexError>(())
314        })?;
315        Ok(())
316    }
317
318    /// Creates a new [`DecimalArray`] from a raw byte buffer without validation.
319    ///
320    /// # Safety
321    ///
322    /// The caller must ensure:
323    /// - The `byte_buffer` contains valid data for the specified `values_type`
324    /// - The buffer length is compatible with the `values_type` (i.e., divisible by the type size)
325    /// - All non-null values are representable within the specified precision
326    /// - If `validity` is [`Validity::Array`], its length must equal the number of elements
327    pub unsafe fn new_unchecked_from_byte_buffer(
328        byte_buffer: ByteBuffer,
329        values_type: DecimalType,
330        decimal_dtype: DecimalDType,
331    ) -> Self {
332        // SAFETY: inherits the same safety contract as `new_unchecked_from_byte_buffer`
333        unsafe {
334            Self::new_unchecked_handle(
335                BufferHandle::new_host(byte_buffer),
336                values_type,
337                decimal_dtype,
338            )
339        }
340    }
341
342    /// Returns the length of this array.
343    pub fn len(&self) -> usize {
344        self.values.len() / self.values_type.byte_width()
345    }
346
347    /// Returns `true` if this array is empty.
348    pub fn is_empty(&self) -> bool {
349        self.len() == 0
350    }
351
352    /// Returns the underlying [`ByteBuffer`] of the array.
353    pub fn buffer_handle(&self) -> &BufferHandle {
354        &self.values
355    }
356
357    pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
358        if self.values_type != T::DECIMAL_TYPE {
359            vortex_panic!(
360                "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
361                T::DECIMAL_TYPE,
362                self.values_type,
363            );
364        }
365        Buffer::<T>::from_byte_buffer(self.values.as_host().clone())
366    }
367
368    /// Return the `DecimalType` used to represent the values in the array.
369    pub fn values_type(&self) -> DecimalType {
370        self.values_type
371    }
372
373    /// Returns the decimal type information.
374    pub fn decimal_dtype(&self) -> DecimalDType {
375        self.decimal_dtype
376    }
377
378    pub fn precision(&self) -> u8 {
379        self.decimal_dtype.precision()
380    }
381
382    pub fn scale(&self) -> i8 {
383        self.decimal_dtype.scale()
384    }
385}
386
387impl Array<Decimal> {
388    pub fn into_data_parts(self) -> DecimalDataParts {
389        let validity = DecimalArrayExt::validity(&self);
390        let decimal_dtype = DecimalArrayExt::decimal_dtype(&self);
391        let data = self.into_data();
392        DecimalDataParts {
393            decimal_dtype,
394            values: data.values,
395            values_type: data.values_type,
396            validity,
397        }
398    }
399}
400
401impl Array<Decimal> {
402    /// Creates a new [`DecimalArray`] using a host-native buffer.
403    pub fn new<T: NativeDecimalType>(
404        buffer: Buffer<T>,
405        decimal_dtype: DecimalDType,
406        validity: Validity,
407    ) -> Self {
408        Self::try_new(buffer, decimal_dtype, validity)
409            .vortex_expect("DecimalArray construction failed")
410    }
411
412    /// Creates a new [`DecimalArray`] without validation.
413    ///
414    /// # Safety
415    ///
416    /// See [`DecimalData::new_unchecked`].
417    pub unsafe fn new_unchecked<T: NativeDecimalType>(
418        buffer: Buffer<T>,
419        decimal_dtype: DecimalDType,
420        validity: Validity,
421    ) -> Self {
422        let dtype = DType::Decimal(decimal_dtype, validity.nullability());
423        let len = buffer.len();
424        let slots = DecimalData::make_slots(&validity, len);
425        let data = unsafe { DecimalData::new_unchecked(buffer, decimal_dtype) };
426        unsafe {
427            Array::from_parts_unchecked(
428                ArrayParts::new(Decimal, dtype, len, data).with_slots(slots),
429            )
430        }
431    }
432
433    /// Creates a new [`DecimalArray`] from a host-native buffer with validation.
434    pub fn try_new<T: NativeDecimalType>(
435        buffer: Buffer<T>,
436        decimal_dtype: DecimalDType,
437        validity: Validity,
438    ) -> VortexResult<Self> {
439        let dtype = DType::Decimal(decimal_dtype, validity.nullability());
440        let len = buffer.len();
441        let slots = DecimalData::make_slots(&validity, len);
442        let data = DecimalData::try_new(buffer, decimal_dtype)?;
443        Array::try_from_parts(ArrayParts::new(Decimal, dtype, len, data).with_slots(slots))
444    }
445
446    /// Creates a new [`DecimalArray`] from an iterator of values.
447    #[expect(
448        clippy::same_name_method,
449        reason = "intentionally named from_iter like Iterator::from_iter"
450    )]
451    pub fn from_iter<T: NativeDecimalType, I: IntoIterator<Item = T>>(
452        iter: I,
453        decimal_dtype: DecimalDType,
454    ) -> Self {
455        Self::new(
456            BufferMut::from_iter(iter).freeze(),
457            decimal_dtype,
458            Validity::NonNullable,
459        )
460    }
461
462    /// Creates a new [`DecimalArray`] from an iterator of optional values.
463    pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
464        iter: I,
465        decimal_dtype: DecimalDType,
466    ) -> Self {
467        let iter = iter.into_iter();
468        let mut values = BufferMut::with_capacity(iter.size_hint().0);
469        let mut validity = BitBufferMut::with_capacity(values.capacity());
470
471        for value in iter {
472            match value {
473                Some(value) => {
474                    values.push(value);
475                    validity.append(true);
476                }
477                None => {
478                    values.push(T::default());
479                    validity.append(false);
480                }
481            }
482        }
483
484        Self::new(
485            values.freeze(),
486            decimal_dtype,
487            Validity::from(validity.freeze()),
488        )
489    }
490
491    /// Creates a new [`DecimalArray`] from a [`BufferHandle`].
492    pub fn new_handle(
493        values: BufferHandle,
494        values_type: DecimalType,
495        decimal_dtype: DecimalDType,
496        validity: Validity,
497    ) -> Self {
498        Self::try_new_handle(values, values_type, decimal_dtype, validity)
499            .vortex_expect("DecimalArray construction failed")
500    }
501
502    /// Creates a new [`DecimalArray`] from a [`BufferHandle`] with validation.
503    pub fn try_new_handle(
504        values: BufferHandle,
505        values_type: DecimalType,
506        decimal_dtype: DecimalDType,
507        validity: Validity,
508    ) -> VortexResult<Self> {
509        let dtype = DType::Decimal(decimal_dtype, validity.nullability());
510        let len = values.len() / values_type.byte_width();
511        let slots = DecimalData::make_slots(&validity, len);
512        let data = DecimalData::try_new_handle(values, values_type, decimal_dtype)?;
513        Array::try_from_parts(ArrayParts::new(Decimal, dtype, len, data).with_slots(slots))
514    }
515
516    /// Creates a new [`DecimalArray`] without validation from a [`BufferHandle`].
517    ///
518    /// # Safety
519    ///
520    /// See [`DecimalData::new_unchecked_handle`].
521    pub unsafe fn new_unchecked_handle(
522        values: BufferHandle,
523        values_type: DecimalType,
524        decimal_dtype: DecimalDType,
525        validity: Validity,
526    ) -> Self {
527        let dtype = DType::Decimal(decimal_dtype, validity.nullability());
528        let len = values.len() / values_type.byte_width();
529        let slots = DecimalData::make_slots(&validity, len);
530        let data = unsafe { DecimalData::new_unchecked_handle(values, values_type, decimal_dtype) };
531        unsafe {
532            Array::from_parts_unchecked(
533                ArrayParts::new(Decimal, dtype, len, data).with_slots(slots),
534            )
535        }
536    }
537
538    #[expect(
539        clippy::cognitive_complexity,
540        reason = "patching depends on both patch and value physical types"
541    )]
542    pub fn patch(self, patches: &Patches, ctx: &mut ExecutionCtx) -> VortexResult<Self> {
543        let offset = patches.offset();
544        let dtype = self.dtype().clone();
545        let len = self.len();
546        let patch_indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
547        let patch_values = patches.values().clone().execute::<DecimalArray>(ctx)?;
548
549        let patch_validity = patch_values.validity()?;
550        let patched_validity = self.validity()?.patch(
551            self.len(),
552            offset,
553            &patch_indices.clone().into_array(),
554            &patch_validity,
555            ctx,
556        )?;
557        assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
558
559        let data = self.into_data();
560        // Patch indices are non-negative; reinterpret to unsigned so this dispatches over 4 widths
561        // instead of 8 (the decimal value-type dimensions are unaffected).
562        let patch_indices_unsigned =
563            patch_indices.reinterpret_cast(patch_indices.ptype().to_unsigned());
564        let data = match_each_unsigned_integer_ptype!(patch_indices_unsigned.ptype(), |I| {
565            let patch_indices = patch_indices_unsigned.as_slice::<I>();
566            match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
567                let patch_values = patch_values.buffer::<PatchDVT>();
568                match_each_decimal_value_type!(data.values_type(), |ValuesDVT| {
569                    let buffer = data.buffer::<ValuesDVT>().into_mut();
570                    patch_typed(
571                        buffer,
572                        data.decimal_dtype(),
573                        patch_indices,
574                        offset,
575                        patch_values,
576                    )
577                })
578            })
579        });
580        let slots = DecimalData::make_slots(&patched_validity, len);
581        Ok(unsafe {
582            Array::from_parts_unchecked(
583                ArrayParts::new(Decimal, dtype, len, data).with_slots(slots),
584            )
585        })
586    }
587}
588
589fn patch_typed<I, ValuesDVT, PatchDVT>(
590    mut buffer: BufferMut<ValuesDVT>,
591    decimal_dtype: DecimalDType,
592    patch_indices: &[I],
593    patch_indices_offset: usize,
594    patch_values: Buffer<PatchDVT>,
595) -> DecimalData
596where
597    I: IntegerPType,
598    PatchDVT: NativeDecimalType,
599    ValuesDVT: NativeDecimalType,
600{
601    if !ValuesDVT::DECIMAL_TYPE.is_compatible_decimal_value_type(decimal_dtype) {
602        vortex_panic!(
603            "patch_typed: {:?} cannot represent every value in {}.",
604            ValuesDVT::DECIMAL_TYPE,
605            decimal_dtype
606        )
607    }
608
609    for (idx, value) in patch_indices.iter().zip_eq(patch_values.into_iter()) {
610        buffer[idx.as_() - patch_indices_offset] = <ValuesDVT as BigCast>::from(value).vortex_expect(
611            "values of a given DecimalDType are representable in all compatible NativeDecimalType",
612        );
613    }
614
615    DecimalData::new(buffer.freeze(), decimal_dtype)
616}