Skip to main content

vortex_array/arrays/decimal/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use smallvec::smallvec;
9use vortex_buffer::Alignment;
10use vortex_buffer::BitBufferMut;
11use vortex_buffer::Buffer;
12use vortex_buffer::BufferMut;
13use vortex_buffer::ByteBuffer;
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_ensure;
17use vortex_error::vortex_panic;
18
19use crate::ArrayRef;
20use crate::ArraySlots;
21use crate::ExecutionCtx;
22use crate::IntoArray;
23use crate::array::Array;
24use crate::array::ArrayParts;
25use crate::array::TypedArrayRef;
26use crate::array::child_to_validity;
27use crate::array::validity_to_child;
28use crate::arrays::Decimal;
29use crate::arrays::DecimalArray;
30use crate::arrays::PrimitiveArray;
31use crate::buffer::BufferHandle;
32use crate::dtype::BigCast;
33use crate::dtype::DType;
34use crate::dtype::DecimalDType;
35use crate::dtype::DecimalType;
36use crate::dtype::IntegerPType;
37use crate::dtype::NativeDecimalType;
38use crate::dtype::Nullability;
39use crate::match_each_decimal_value_type;
40use crate::match_each_integer_ptype;
41use crate::patches::Patches;
42use crate::validity::Validity;
43
/// Index of the validity child within the array's slots.
///
/// The slot holds the validity bitmap indicating which elements are non-null.
pub(super) const VALIDITY_SLOT: usize = 0;
/// Total number of slots; also the length of [`SLOT_NAMES`].
pub(super) const NUM_SLOTS: usize = 1;
/// Name of each slot, indexed by slot position.
pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
48
/// A decimal array that stores fixed-precision decimal numbers with configurable scale.
///
/// This mirrors the Apache Arrow Decimal encoding and provides exact arithmetic for
/// financial and scientific computations where floating-point precision loss is unacceptable.
///
/// ## Storage Format
///
/// Decimals are stored as scaled integers in a supported scalar value type.
///
/// The precisions supported for each scalar type are:
/// - **i8**: precision 1-2 digits
/// - **i16**: precision 3-4 digits
/// - **i32**: precision 5-9 digits
/// - **i64**: precision 10-18 digits
/// - **i128**: precision 19-38 digits
/// - **i256**: precision 39-76 digits
///
/// These are just the maximal ranges for each scalar type, but it is perfectly legal to store
/// values with precision that does not match this exactly. For example, a valid DecimalArray with
/// precision=39 may store its values in an `i8` if all of the actual values fit into it.
///
/// Similarly, a `DecimalArray` can be built that stores a set of precision=2 values in a
/// `Buffer<i256>`.
///
/// ## Precision and Scale
///
/// - **Precision**: Total number of significant digits (1-76, u8 range)
/// - **Scale**: Number of digits after the decimal point (-128 to 127, i8 range)
/// - **Value**: `stored_integer / 10^scale`
///
/// For example, with precision=5 and scale=2:
/// - Stored value 12345 represents 123.45
/// - Range: -999.99 to 999.99
///
/// ## Valid Scalar Types
///
/// The underlying storage uses one of these native types:
/// - `DecimalType::I8`, `I16`, `I32`, `I64`, `I128`, `I256`
/// - The storage type is taken from the element type of the buffer supplied at construction
///   (or passed explicitly alongside a [`BufferHandle`]); it is not derived from the precision.
///
/// # Examples
///
/// ```
/// use vortex_array::arrays::DecimalArray;
/// use vortex_array::dtype::DecimalDType;
/// use vortex_buffer::{buffer, Buffer};
/// use vortex_array::validity::Validity;
///
/// // Create a decimal array with precision=5, scale=2 (e.g., 123.45)
/// let decimal_dtype = DecimalDType::new(5, 2);
/// let values = buffer![12345i32, 67890i32, -12300i32]; // 123.45, 678.90, -123.00
/// let array = DecimalArray::new(values, decimal_dtype, Validity::NonNullable);
///
/// assert_eq!(array.precision(), 5);
/// assert_eq!(array.scale(), 2);
/// assert_eq!(array.len(), 3);
/// ```
#[derive(Clone, Debug)]
pub struct DecimalData {
    /// Precision and scale of the logical decimal type.
    pub(super) decimal_dtype: DecimalDType,
    /// Handle to the raw bytes of the scaled-integer values.
    pub(super) values: BufferHandle,
    /// Physical integer type used to store each element of `values`.
    pub(super) values_type: DecimalType,
}
112
113impl Display for DecimalData {
114    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
115        write!(
116            f,
117            "decimal_dtype: {}, values_type: {}",
118            self.decimal_dtype, self.values_type
119        )
120    }
121}
122
/// The owned components of a decimal array, as produced by `Array<Decimal>::into_data_parts`.
pub struct DecimalDataParts {
    /// Precision and scale of the logical decimal type.
    pub decimal_dtype: DecimalDType,
    /// Handle to the raw bytes of the scaled-integer values.
    pub values: BufferHandle,
    /// Physical integer type used to store each element of `values`.
    pub values_type: DecimalType,
    /// Validity of the array's elements.
    pub validity: Validity,
}
129
130pub trait DecimalArrayExt: TypedArrayRef<Decimal> {
131    fn decimal_dtype(&self) -> DecimalDType {
132        match self.as_ref().dtype() {
133            DType::Decimal(decimal_dtype, _) => *decimal_dtype,
134            _ => unreachable!("DecimalArrayExt requires a decimal dtype"),
135        }
136    }
137
138    fn nullability(&self) -> Nullability {
139        match self.as_ref().dtype() {
140            DType::Decimal(_, nullability) => *nullability,
141            _ => unreachable!("DecimalArrayExt requires a decimal dtype"),
142        }
143    }
144
145    fn validity_child(&self) -> Option<&ArrayRef> {
146        self.as_ref().slots()[VALIDITY_SLOT].as_ref()
147    }
148
149    fn validity(&self) -> Validity {
150        child_to_validity(
151            self.as_ref().slots()[VALIDITY_SLOT].as_ref(),
152            self.nullability(),
153        )
154    }
155
156    fn values_type(&self) -> DecimalType {
157        self.values_type
158    }
159
160    fn precision(&self) -> u8 {
161        self.decimal_dtype().precision()
162    }
163
164    fn scale(&self) -> i8 {
165        self.decimal_dtype().scale()
166    }
167
168    fn buffer_handle(&self) -> &BufferHandle {
169        &self.values
170    }
171
172    fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
173        DecimalData::buffer::<T>(self)
174    }
175}
176impl<T: TypedArrayRef<Decimal>> DecimalArrayExt for T {}
177
178impl DecimalData {
179    /// Build the slots vector for this array.
180    pub(super) fn make_slots(validity: &Validity, len: usize) -> ArraySlots {
181        smallvec![validity_to_child(validity, len)]
182    }
183
184    /// Creates a new [`DecimalArray`] using a host-native buffer.
185    ///
186    /// # Panics
187    ///
188    /// Panics if the provided components do not satisfy the invariants documented in
189    /// [`DecimalArray::new_unchecked`].
190    pub fn new<T: NativeDecimalType>(buffer: Buffer<T>, decimal_dtype: DecimalDType) -> Self {
191        Self::try_new(buffer, decimal_dtype).vortex_expect("DecimalArray construction failed")
192    }
193
194    /// Creates a new [`DecimalArray`] from a [`BufferHandle`] of values that may live in
195    /// host or device memory.
196    ///
197    /// # Panics
198    ///
199    /// Panics if the provided components do not satisfy the invariants documented in
200    /// [`DecimalArray::new_unchecked`].
201    pub fn new_handle(
202        values: BufferHandle,
203        values_type: DecimalType,
204        decimal_dtype: DecimalDType,
205    ) -> Self {
206        Self::try_new_handle(values, values_type, decimal_dtype)
207            .vortex_expect("DecimalArray construction failed")
208    }
209
210    /// Constructs a new `DecimalArray`.
211    ///
212    /// See [`DecimalArray::new_unchecked`] for more information.
213    ///
214    /// # Errors
215    ///
216    /// Returns an error if the provided components do not satisfy the invariants documented in
217    /// [`DecimalArray::new_unchecked`].
218    pub fn try_new<T: NativeDecimalType>(
219        buffer: Buffer<T>,
220        decimal_dtype: DecimalDType,
221    ) -> VortexResult<Self> {
222        let values = BufferHandle::new_host(buffer.into_byte_buffer());
223        let values_type = T::DECIMAL_TYPE;
224
225        Self::try_new_handle(values, values_type, decimal_dtype)
226    }
227
228    /// Constructs a new `DecimalArray` with validation from a [`BufferHandle`].
229    ///
230    /// This pathway allows building new decimal arrays that may come from host or device memory.
231    ///
232    /// # Errors
233    ///
234    /// See [`DecimalArray::new_unchecked`] for invariants that are checked.
235    pub fn try_new_handle(
236        values: BufferHandle,
237        values_type: DecimalType,
238        decimal_dtype: DecimalDType,
239    ) -> VortexResult<Self> {
240        Self::validate(&values, values_type)?;
241
242        // SAFETY: validate ensures all invariants are met.
243        Ok(unsafe { Self::new_unchecked_handle(values, values_type, decimal_dtype) })
244    }
245
246    /// Creates a new [`DecimalArray`] without validation from these components:
247    ///
248    /// * `buffer` is a typed buffer containing the decimal values.
249    /// * `decimal_dtype` specifies the decimal precision and scale.
250    /// * `validity` holds the null values.
251    ///
252    /// # Safety
253    ///
254    /// The caller must ensure all of the following invariants are satisfied:
255    ///
256    /// - All non-null values in `buffer` must be representable within the specified precision.
257    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
258    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
259    pub unsafe fn new_unchecked<T: NativeDecimalType>(
260        buffer: Buffer<T>,
261        decimal_dtype: DecimalDType,
262    ) -> Self {
263        // SAFETY: new_unchecked_handle inherits the safety guarantees of new_unchecked
264        unsafe {
265            Self::new_unchecked_handle(
266                BufferHandle::new_host(buffer.into_byte_buffer()),
267                T::DECIMAL_TYPE,
268                decimal_dtype,
269            )
270        }
271    }
272
273    /// Create a new array with decimal values backed by the given buffer handle.
274    ///
275    /// # Safety
276    ///
277    /// The caller must ensure all of the following invariants are satisfied:
278    ///
279    /// - All non-null values in `values` must be representable within the specified precision.
280    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
281    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
282    pub unsafe fn new_unchecked_handle(
283        values: BufferHandle,
284        values_type: DecimalType,
285        decimal_dtype: DecimalDType,
286    ) -> Self {
287        Self {
288            decimal_dtype,
289            values,
290            values_type,
291        }
292    }
293
294    /// Validates the components that would be used to create a [`DecimalArray`] from a byte buffer.
295    ///
296    /// This function checks all the invariants required by [`DecimalArray::new_unchecked`].
297    fn validate(buffer: &BufferHandle, values_type: DecimalType) -> VortexResult<()> {
298        let byte_width = values_type.byte_width();
299        vortex_ensure!(
300            buffer.len().is_multiple_of(byte_width),
301            InvalidArgument: "decimal buffer size {} is not divisible by element width {}",
302            buffer.len(),
303            byte_width,
304        );
305        match_each_decimal_value_type!(values_type, |D| {
306            vortex_ensure!(
307                buffer.is_aligned_to(Alignment::of::<D>()),
308                InvalidArgument: "decimal buffer alignment {:?} is invalid for values type {:?}",
309                buffer.alignment(),
310                D::DECIMAL_TYPE,
311            );
312            Ok::<(), vortex_error::VortexError>(())
313        })?;
314        Ok(())
315    }
316
317    /// Creates a new [`DecimalArray`] from a raw byte buffer without validation.
318    ///
319    /// # Safety
320    ///
321    /// The caller must ensure:
322    /// - The `byte_buffer` contains valid data for the specified `values_type`
323    /// - The buffer length is compatible with the `values_type` (i.e., divisible by the type size)
324    /// - All non-null values are representable within the specified precision
325    /// - If `validity` is [`Validity::Array`], its length must equal the number of elements
326    pub unsafe fn new_unchecked_from_byte_buffer(
327        byte_buffer: ByteBuffer,
328        values_type: DecimalType,
329        decimal_dtype: DecimalDType,
330    ) -> Self {
331        // SAFETY: inherits the same safety contract as `new_unchecked_from_byte_buffer`
332        unsafe {
333            Self::new_unchecked_handle(
334                BufferHandle::new_host(byte_buffer),
335                values_type,
336                decimal_dtype,
337            )
338        }
339    }
340
341    /// Returns the length of this array.
342    pub fn len(&self) -> usize {
343        self.values.len() / self.values_type.byte_width()
344    }
345
346    /// Returns `true` if this array is empty.
347    pub fn is_empty(&self) -> bool {
348        self.len() == 0
349    }
350
    /// Returns the [`BufferHandle`] holding the raw value bytes of the array.
    pub fn buffer_handle(&self) -> &BufferHandle {
        &self.values
    }
355
356    pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
357        if self.values_type != T::DECIMAL_TYPE {
358            vortex_panic!(
359                "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
360                T::DECIMAL_TYPE,
361                self.values_type,
362            );
363        }
364        Buffer::<T>::from_byte_buffer(self.values.as_host().clone())
365    }
366
    /// Returns the [`DecimalType`] used to physically store the values in the array.
    pub fn values_type(&self) -> DecimalType {
        self.values_type
    }
371
    /// Returns the logical decimal type (precision and scale) of the array.
    pub fn decimal_dtype(&self) -> DecimalDType {
        self.decimal_dtype
    }
376
    /// Returns the precision of the array's decimal dtype.
    pub fn precision(&self) -> u8 {
        self.decimal_dtype.precision()
    }
380
    /// Returns the scale of the array's decimal dtype.
    pub fn scale(&self) -> i8 {
        self.decimal_dtype.scale()
    }
384}
385
386impl Array<Decimal> {
387    pub fn into_data_parts(self) -> DecimalDataParts {
388        let validity = DecimalArrayExt::validity(&self);
389        let decimal_dtype = DecimalArrayExt::decimal_dtype(&self);
390        let data = self.into_data();
391        DecimalDataParts {
392            decimal_dtype,
393            values: data.values,
394            values_type: data.values_type,
395            validity,
396        }
397    }
398}
399
400impl Array<Decimal> {
401    /// Creates a new [`DecimalArray`] using a host-native buffer.
402    pub fn new<T: NativeDecimalType>(
403        buffer: Buffer<T>,
404        decimal_dtype: DecimalDType,
405        validity: Validity,
406    ) -> Self {
407        Self::try_new(buffer, decimal_dtype, validity)
408            .vortex_expect("DecimalArray construction failed")
409    }
410
411    /// Creates a new [`DecimalArray`] without validation.
412    ///
413    /// # Safety
414    ///
415    /// See [`DecimalData::new_unchecked`].
416    pub unsafe fn new_unchecked<T: NativeDecimalType>(
417        buffer: Buffer<T>,
418        decimal_dtype: DecimalDType,
419        validity: Validity,
420    ) -> Self {
421        let dtype = DType::Decimal(decimal_dtype, validity.nullability());
422        let len = buffer.len();
423        let slots = DecimalData::make_slots(&validity, len);
424        let data = unsafe { DecimalData::new_unchecked(buffer, decimal_dtype) };
425        unsafe {
426            Array::from_parts_unchecked(
427                ArrayParts::new(Decimal, dtype, len, data).with_slots(slots),
428            )
429        }
430    }
431
432    /// Creates a new [`DecimalArray`] from a host-native buffer with validation.
433    pub fn try_new<T: NativeDecimalType>(
434        buffer: Buffer<T>,
435        decimal_dtype: DecimalDType,
436        validity: Validity,
437    ) -> VortexResult<Self> {
438        let dtype = DType::Decimal(decimal_dtype, validity.nullability());
439        let len = buffer.len();
440        let slots = DecimalData::make_slots(&validity, len);
441        let data = DecimalData::try_new(buffer, decimal_dtype)?;
442        Array::try_from_parts(ArrayParts::new(Decimal, dtype, len, data).with_slots(slots))
443    }
444
445    /// Creates a new [`DecimalArray`] from an iterator of values.
446    #[expect(
447        clippy::same_name_method,
448        reason = "intentionally named from_iter like Iterator::from_iter"
449    )]
450    pub fn from_iter<T: NativeDecimalType, I: IntoIterator<Item = T>>(
451        iter: I,
452        decimal_dtype: DecimalDType,
453    ) -> Self {
454        Self::new(
455            BufferMut::from_iter(iter).freeze(),
456            decimal_dtype,
457            Validity::NonNullable,
458        )
459    }
460
461    /// Creates a new [`DecimalArray`] from an iterator of optional values.
462    pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
463        iter: I,
464        decimal_dtype: DecimalDType,
465    ) -> Self {
466        let iter = iter.into_iter();
467        let mut values = BufferMut::with_capacity(iter.size_hint().0);
468        let mut validity = BitBufferMut::with_capacity(values.capacity());
469
470        for value in iter {
471            match value {
472                Some(value) => {
473                    values.push(value);
474                    validity.append(true);
475                }
476                None => {
477                    values.push(T::default());
478                    validity.append(false);
479                }
480            }
481        }
482
483        Self::new(
484            values.freeze(),
485            decimal_dtype,
486            Validity::from(validity.freeze()),
487        )
488    }
489
490    /// Creates a new [`DecimalArray`] from a [`BufferHandle`].
491    pub fn new_handle(
492        values: BufferHandle,
493        values_type: DecimalType,
494        decimal_dtype: DecimalDType,
495        validity: Validity,
496    ) -> Self {
497        Self::try_new_handle(values, values_type, decimal_dtype, validity)
498            .vortex_expect("DecimalArray construction failed")
499    }
500
501    /// Creates a new [`DecimalArray`] from a [`BufferHandle`] with validation.
502    pub fn try_new_handle(
503        values: BufferHandle,
504        values_type: DecimalType,
505        decimal_dtype: DecimalDType,
506        validity: Validity,
507    ) -> VortexResult<Self> {
508        let dtype = DType::Decimal(decimal_dtype, validity.nullability());
509        let len = values.len() / values_type.byte_width();
510        let slots = DecimalData::make_slots(&validity, len);
511        let data = DecimalData::try_new_handle(values, values_type, decimal_dtype)?;
512        Array::try_from_parts(ArrayParts::new(Decimal, dtype, len, data).with_slots(slots))
513    }
514
515    /// Creates a new [`DecimalArray`] without validation from a [`BufferHandle`].
516    ///
517    /// # Safety
518    ///
519    /// See [`DecimalData::new_unchecked_handle`].
520    pub unsafe fn new_unchecked_handle(
521        values: BufferHandle,
522        values_type: DecimalType,
523        decimal_dtype: DecimalDType,
524        validity: Validity,
525    ) -> Self {
526        let dtype = DType::Decimal(decimal_dtype, validity.nullability());
527        let len = values.len() / values_type.byte_width();
528        let slots = DecimalData::make_slots(&validity, len);
529        let data = unsafe { DecimalData::new_unchecked_handle(values, values_type, decimal_dtype) };
530        unsafe {
531            Array::from_parts_unchecked(
532                ArrayParts::new(Decimal, dtype, len, data).with_slots(slots),
533            )
534        }
535    }
536
537    #[expect(
538        clippy::cognitive_complexity,
539        reason = "patching depends on both patch and value physical types"
540    )]
541    pub fn patch(self, patches: &Patches, ctx: &mut ExecutionCtx) -> VortexResult<Self> {
542        let offset = patches.offset();
543        let dtype = self.dtype().clone();
544        let len = self.len();
545        let patch_indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
546        let patch_values = patches.values().clone().execute::<DecimalArray>(ctx)?;
547
548        let patch_validity = patch_values.validity()?;
549        let patched_validity = self.validity()?.patch(
550            self.len(),
551            offset,
552            &patch_indices.clone().into_array(),
553            &patch_validity,
554            ctx,
555        )?;
556        assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
557
558        let data = self.into_data();
559        let data = match_each_integer_ptype!(patch_indices.ptype(), |I| {
560            let patch_indices = patch_indices.as_slice::<I>();
561            match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
562                let patch_values = patch_values.buffer::<PatchDVT>();
563                match_each_decimal_value_type!(data.values_type(), |ValuesDVT| {
564                    let buffer = data.buffer::<ValuesDVT>().into_mut();
565                    patch_typed(
566                        buffer,
567                        data.decimal_dtype(),
568                        patch_indices,
569                        offset,
570                        patch_values,
571                    )
572                })
573            })
574        });
575        let slots = DecimalData::make_slots(&patched_validity, len);
576        Ok(unsafe {
577            Array::from_parts_unchecked(
578                ArrayParts::new(Decimal, dtype, len, data).with_slots(slots),
579            )
580        })
581    }
582}
583
584fn patch_typed<I, ValuesDVT, PatchDVT>(
585    mut buffer: BufferMut<ValuesDVT>,
586    decimal_dtype: DecimalDType,
587    patch_indices: &[I],
588    patch_indices_offset: usize,
589    patch_values: Buffer<PatchDVT>,
590) -> DecimalData
591where
592    I: IntegerPType,
593    PatchDVT: NativeDecimalType,
594    ValuesDVT: NativeDecimalType,
595{
596    if !ValuesDVT::DECIMAL_TYPE.is_compatible_decimal_value_type(decimal_dtype) {
597        vortex_panic!(
598            "patch_typed: {:?} cannot represent every value in {}.",
599            ValuesDVT::DECIMAL_TYPE,
600            decimal_dtype
601        )
602    }
603
604    for (idx, value) in patch_indices.iter().zip_eq(patch_values.into_iter()) {
605        buffer[idx.as_() - patch_indices_offset] = <ValuesDVT as BigCast>::from(value).vortex_expect(
606            "values of a given DecimalDType are representable in all compatible NativeDecimalType",
607        );
608    }
609
610    DecimalData::new(buffer.freeze(), decimal_dtype)
611}