Skip to main content

vortex_array/arrays/decimal/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use vortex_buffer::Alignment;
9use vortex_buffer::BitBufferMut;
10use vortex_buffer::Buffer;
11use vortex_buffer::BufferMut;
12use vortex_buffer::ByteBuffer;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_ensure;
16use vortex_error::vortex_panic;
17
18use crate::ArrayRef;
19use crate::ExecutionCtx;
20use crate::IntoArray;
21use crate::array::Array;
22use crate::array::ArrayParts;
23use crate::array::TypedArrayRef;
24use crate::array::child_to_validity;
25use crate::array::validity_to_child;
26use crate::arrays::Decimal;
27use crate::arrays::DecimalArray;
28use crate::arrays::PrimitiveArray;
29use crate::buffer::BufferHandle;
30use crate::dtype::BigCast;
31use crate::dtype::DType;
32use crate::dtype::DecimalDType;
33use crate::dtype::DecimalType;
34use crate::dtype::IntegerPType;
35use crate::dtype::NativeDecimalType;
36use crate::dtype::Nullability;
37use crate::match_each_decimal_value_type;
38use crate::match_each_integer_ptype;
39use crate::patches::Patches;
40use crate::validity::Validity;
41
42/// The validity bitmap indicating which elements are non-null.
43pub(super) const VALIDITY_SLOT: usize = 0;
44pub(super) const NUM_SLOTS: usize = 1;
45pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
46
47/// A decimal array that stores fixed-precision decimal numbers with configurable scale.
48///
49/// This mirrors the Apache Arrow Decimal encoding and provides exact arithmetic for
50/// financial and scientific computations where floating-point precision loss is unacceptable.
51///
52/// ## Storage Format
53///
54/// Decimals are stored as scaled integers in a supported scalar value type.
55///
56/// The precisions supported for each scalar type are:
57/// - **i8**: precision 1-2 digits
58/// - **i16**: precision 3-4 digits
59/// - **i32**: precision 5-9 digits
60/// - **i64**: precision 10-18 digits
61/// - **i128**: precision 19-38 digits
62/// - **i256**: precision 39-76 digits
63///
64/// These are just the maximal ranges for each scalar type, but it is perfectly legal to store
65/// values with precision that does not match this exactly. For example, a valid DecimalArray with
66/// precision=39 may store its values in an `i8` if all of the actual values fit into it.
67///
68/// Similarly, a `DecimalArray` can be built that stores a set of precision=2 values in a
69/// `Buffer<i256>`.
70///
71/// ## Precision and Scale
72///
73/// - **Precision**: Total number of significant digits (1-76, u8 range)
74/// - **Scale**: Number of digits after the decimal point (-128 to 127, i8 range)
75/// - **Value**: `stored_integer / 10^scale`
76///
77/// For example, with precision=5 and scale=2:
78/// - Stored value 12345 represents 123.45
79/// - Range: -999.99 to 999.99
80///
81/// ## Valid Scalar Types
82///
83/// The underlying storage uses these native types based on precision:
84/// - `DecimalType::I8`, `I16`, `I32`, `I64`, `I128`, `I256`
85/// - Type selection is automatic based on the required precision
86///
87/// # Examples
88///
89/// ```
90/// use vortex_array::arrays::DecimalArray;
91/// use vortex_array::dtype::DecimalDType;
92/// use vortex_buffer::{buffer, Buffer};
93/// use vortex_array::validity::Validity;
94///
95/// // Create a decimal array with precision=5, scale=2 (e.g., 123.45)
96/// let decimal_dtype = DecimalDType::new(5, 2);
97/// let values = buffer![12345i32, 67890i32, -12300i32]; // 123.45, 678.90, -123.00
98/// let array = DecimalArray::new(values, decimal_dtype, Validity::NonNullable);
99///
100/// assert_eq!(array.precision(), 5);
101/// assert_eq!(array.scale(), 2);
102/// assert_eq!(array.len(), 3);
103/// ```
104#[derive(Clone, Debug)]
105pub struct DecimalData {
106    pub(super) decimal_dtype: DecimalDType,
107    pub(super) values: BufferHandle,
108    pub(super) values_type: DecimalType,
109}
110
111impl Display for DecimalData {
112    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
113        write!(
114            f,
115            "decimal_dtype: {}, values_type: {}",
116            self.decimal_dtype, self.values_type
117        )
118    }
119}
120
121pub struct DecimalDataParts {
122    pub decimal_dtype: DecimalDType,
123    pub values: BufferHandle,
124    pub values_type: DecimalType,
125    pub validity: Validity,
126}
127
128pub trait DecimalArrayExt: TypedArrayRef<Decimal> {
129    fn decimal_dtype(&self) -> DecimalDType {
130        match self.as_ref().dtype() {
131            DType::Decimal(decimal_dtype, _) => *decimal_dtype,
132            _ => unreachable!("DecimalArrayExt requires a decimal dtype"),
133        }
134    }
135
136    fn nullability(&self) -> Nullability {
137        match self.as_ref().dtype() {
138            DType::Decimal(_, nullability) => *nullability,
139            _ => unreachable!("DecimalArrayExt requires a decimal dtype"),
140        }
141    }
142
143    fn validity_child(&self) -> Option<&ArrayRef> {
144        self.as_ref().slots()[VALIDITY_SLOT].as_ref()
145    }
146
147    fn validity(&self) -> Validity {
148        child_to_validity(&self.as_ref().slots()[VALIDITY_SLOT], self.nullability())
149    }
150
151    fn values_type(&self) -> DecimalType {
152        self.values_type
153    }
154
155    fn precision(&self) -> u8 {
156        self.decimal_dtype().precision()
157    }
158
159    fn scale(&self) -> i8 {
160        self.decimal_dtype().scale()
161    }
162
163    fn buffer_handle(&self) -> &BufferHandle {
164        &self.values
165    }
166
167    fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
168        DecimalData::buffer::<T>(self)
169    }
170}
171impl<T: TypedArrayRef<Decimal>> DecimalArrayExt for T {}
172
173impl DecimalData {
174    /// Build the slots vector for this array.
175    pub(super) fn make_slots(validity: &Validity, len: usize) -> Vec<Option<ArrayRef>> {
176        vec![validity_to_child(validity, len)]
177    }
178
179    /// Creates a new [`DecimalArray`] using a host-native buffer.
180    ///
181    /// # Panics
182    ///
183    /// Panics if the provided components do not satisfy the invariants documented in
184    /// [`DecimalArray::new_unchecked`].
185    pub fn new<T: NativeDecimalType>(buffer: Buffer<T>, decimal_dtype: DecimalDType) -> Self {
186        Self::try_new(buffer, decimal_dtype).vortex_expect("DecimalArray construction failed")
187    }
188
189    /// Creates a new [`DecimalArray`] from a [`BufferHandle`] of values that may live in
190    /// host or device memory.
191    ///
192    /// # Panics
193    ///
194    /// Panics if the provided components do not satisfy the invariants documented in
195    /// [`DecimalArray::new_unchecked`].
196    pub fn new_handle(
197        values: BufferHandle,
198        values_type: DecimalType,
199        decimal_dtype: DecimalDType,
200    ) -> Self {
201        Self::try_new_handle(values, values_type, decimal_dtype)
202            .vortex_expect("DecimalArray construction failed")
203    }
204
205    /// Constructs a new `DecimalArray`.
206    ///
207    /// See [`DecimalArray::new_unchecked`] for more information.
208    ///
209    /// # Errors
210    ///
211    /// Returns an error if the provided components do not satisfy the invariants documented in
212    /// [`DecimalArray::new_unchecked`].
213    pub fn try_new<T: NativeDecimalType>(
214        buffer: Buffer<T>,
215        decimal_dtype: DecimalDType,
216    ) -> VortexResult<Self> {
217        let values = BufferHandle::new_host(buffer.into_byte_buffer());
218        let values_type = T::DECIMAL_TYPE;
219
220        Self::try_new_handle(values, values_type, decimal_dtype)
221    }
222
223    /// Constructs a new `DecimalArray` with validation from a [`BufferHandle`].
224    ///
225    /// This pathway allows building new decimal arrays that may come from host or device memory.
226    ///
227    /// # Errors
228    ///
229    /// See [`DecimalArray::new_unchecked`] for invariants that are checked.
230    pub fn try_new_handle(
231        values: BufferHandle,
232        values_type: DecimalType,
233        decimal_dtype: DecimalDType,
234    ) -> VortexResult<Self> {
235        Self::validate(&values, values_type)?;
236
237        // SAFETY: validate ensures all invariants are met.
238        Ok(unsafe { Self::new_unchecked_handle(values, values_type, decimal_dtype) })
239    }
240
241    /// Creates a new [`DecimalArray`] without validation from these components:
242    ///
243    /// * `buffer` is a typed buffer containing the decimal values.
244    /// * `decimal_dtype` specifies the decimal precision and scale.
245    /// * `validity` holds the null values.
246    ///
247    /// # Safety
248    ///
249    /// The caller must ensure all of the following invariants are satisfied:
250    ///
251    /// - All non-null values in `buffer` must be representable within the specified precision.
252    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
253    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
254    pub unsafe fn new_unchecked<T: NativeDecimalType>(
255        buffer: Buffer<T>,
256        decimal_dtype: DecimalDType,
257    ) -> Self {
258        // SAFETY: new_unchecked_handle inherits the safety guarantees of new_unchecked
259        unsafe {
260            Self::new_unchecked_handle(
261                BufferHandle::new_host(buffer.into_byte_buffer()),
262                T::DECIMAL_TYPE,
263                decimal_dtype,
264            )
265        }
266    }
267
268    /// Create a new array with decimal values backed by the given buffer handle.
269    ///
270    /// # Safety
271    ///
272    /// The caller must ensure all of the following invariants are satisfied:
273    ///
274    /// - All non-null values in `values` must be representable within the specified precision.
275    /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
276    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
277    pub unsafe fn new_unchecked_handle(
278        values: BufferHandle,
279        values_type: DecimalType,
280        decimal_dtype: DecimalDType,
281    ) -> Self {
282        Self {
283            decimal_dtype,
284            values,
285            values_type,
286        }
287    }
288
289    /// Validates the components that would be used to create a [`DecimalArray`] from a byte buffer.
290    ///
291    /// This function checks all the invariants required by [`DecimalArray::new_unchecked`].
292    fn validate(buffer: &BufferHandle, values_type: DecimalType) -> VortexResult<()> {
293        let byte_width = values_type.byte_width();
294        vortex_ensure!(
295            buffer.len().is_multiple_of(byte_width),
296            InvalidArgument: "decimal buffer size {} is not divisible by element width {}",
297            buffer.len(),
298            byte_width,
299        );
300        match_each_decimal_value_type!(values_type, |D| {
301            vortex_ensure!(
302                buffer.is_aligned_to(Alignment::of::<D>()),
303                InvalidArgument: "decimal buffer alignment {:?} is invalid for values type {:?}",
304                buffer.alignment(),
305                D::DECIMAL_TYPE,
306            );
307            Ok::<(), vortex_error::VortexError>(())
308        })?;
309        Ok(())
310    }
311
312    /// Creates a new [`DecimalArray`] from a raw byte buffer without validation.
313    ///
314    /// # Safety
315    ///
316    /// The caller must ensure:
317    /// - The `byte_buffer` contains valid data for the specified `values_type`
318    /// - The buffer length is compatible with the `values_type` (i.e., divisible by the type size)
319    /// - All non-null values are representable within the specified precision
320    /// - If `validity` is [`Validity::Array`], its length must equal the number of elements
321    pub unsafe fn new_unchecked_from_byte_buffer(
322        byte_buffer: ByteBuffer,
323        values_type: DecimalType,
324        decimal_dtype: DecimalDType,
325    ) -> Self {
326        // SAFETY: inherits the same safety contract as `new_unchecked_from_byte_buffer`
327        unsafe {
328            Self::new_unchecked_handle(
329                BufferHandle::new_host(byte_buffer),
330                values_type,
331                decimal_dtype,
332            )
333        }
334    }
335
336    /// Returns the length of this array.
337    pub fn len(&self) -> usize {
338        self.values.len() / self.values_type.byte_width()
339    }
340
341    /// Returns `true` if this array is empty.
342    pub fn is_empty(&self) -> bool {
343        self.len() == 0
344    }
345
346    /// Returns the underlying [`ByteBuffer`] of the array.
347    pub fn buffer_handle(&self) -> &BufferHandle {
348        &self.values
349    }
350
351    pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
352        if self.values_type != T::DECIMAL_TYPE {
353            vortex_panic!(
354                "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
355                T::DECIMAL_TYPE,
356                self.values_type,
357            );
358        }
359        Buffer::<T>::from_byte_buffer(self.values.as_host().clone())
360    }
361
362    /// Return the `DecimalType` used to represent the values in the array.
363    pub fn values_type(&self) -> DecimalType {
364        self.values_type
365    }
366
367    /// Returns the decimal type information.
368    pub fn decimal_dtype(&self) -> DecimalDType {
369        self.decimal_dtype
370    }
371
372    pub fn precision(&self) -> u8 {
373        self.decimal_dtype.precision()
374    }
375
376    pub fn scale(&self) -> i8 {
377        self.decimal_dtype.scale()
378    }
379}
380
381impl Array<Decimal> {
382    pub fn into_data_parts(self) -> DecimalDataParts {
383        let validity = DecimalArrayExt::validity(&self);
384        let decimal_dtype = DecimalArrayExt::decimal_dtype(&self);
385        let data = self.into_data();
386        DecimalDataParts {
387            decimal_dtype,
388            values: data.values,
389            values_type: data.values_type,
390            validity,
391        }
392    }
393}
394
395impl Array<Decimal> {
396    /// Creates a new [`DecimalArray`] using a host-native buffer.
397    pub fn new<T: NativeDecimalType>(
398        buffer: Buffer<T>,
399        decimal_dtype: DecimalDType,
400        validity: Validity,
401    ) -> Self {
402        Self::try_new(buffer, decimal_dtype, validity)
403            .vortex_expect("DecimalArray construction failed")
404    }
405
406    /// Creates a new [`DecimalArray`] without validation.
407    ///
408    /// # Safety
409    ///
410    /// See [`DecimalData::new_unchecked`].
411    pub unsafe fn new_unchecked<T: NativeDecimalType>(
412        buffer: Buffer<T>,
413        decimal_dtype: DecimalDType,
414        validity: Validity,
415    ) -> Self {
416        let dtype = DType::Decimal(decimal_dtype, validity.nullability());
417        let len = buffer.len();
418        let slots = DecimalData::make_slots(&validity, len);
419        let data = unsafe { DecimalData::new_unchecked(buffer, decimal_dtype) };
420        unsafe {
421            Array::from_parts_unchecked(
422                ArrayParts::new(Decimal, dtype, len, data).with_slots(slots),
423            )
424        }
425    }
426
427    /// Creates a new [`DecimalArray`] from a host-native buffer with validation.
428    pub fn try_new<T: NativeDecimalType>(
429        buffer: Buffer<T>,
430        decimal_dtype: DecimalDType,
431        validity: Validity,
432    ) -> VortexResult<Self> {
433        let dtype = DType::Decimal(decimal_dtype, validity.nullability());
434        let len = buffer.len();
435        let slots = DecimalData::make_slots(&validity, len);
436        let data = DecimalData::try_new(buffer, decimal_dtype)?;
437        Array::try_from_parts(ArrayParts::new(Decimal, dtype, len, data).with_slots(slots))
438    }
439
440    /// Creates a new [`DecimalArray`] from an iterator of values.
441    #[expect(
442        clippy::same_name_method,
443        reason = "intentionally named from_iter like Iterator::from_iter"
444    )]
445    pub fn from_iter<T: NativeDecimalType, I: IntoIterator<Item = T>>(
446        iter: I,
447        decimal_dtype: DecimalDType,
448    ) -> Self {
449        Self::new(
450            BufferMut::from_iter(iter).freeze(),
451            decimal_dtype,
452            Validity::NonNullable,
453        )
454    }
455
456    /// Creates a new [`DecimalArray`] from an iterator of optional values.
457    pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
458        iter: I,
459        decimal_dtype: DecimalDType,
460    ) -> Self {
461        let iter = iter.into_iter();
462        let mut values = BufferMut::with_capacity(iter.size_hint().0);
463        let mut validity = BitBufferMut::with_capacity(values.capacity());
464
465        for value in iter {
466            match value {
467                Some(value) => {
468                    values.push(value);
469                    validity.append(true);
470                }
471                None => {
472                    values.push(T::default());
473                    validity.append(false);
474                }
475            }
476        }
477
478        Self::new(
479            values.freeze(),
480            decimal_dtype,
481            Validity::from(validity.freeze()),
482        )
483    }
484
485    /// Creates a new [`DecimalArray`] from a [`BufferHandle`].
486    pub fn new_handle(
487        values: BufferHandle,
488        values_type: DecimalType,
489        decimal_dtype: DecimalDType,
490        validity: Validity,
491    ) -> Self {
492        Self::try_new_handle(values, values_type, decimal_dtype, validity)
493            .vortex_expect("DecimalArray construction failed")
494    }
495
496    /// Creates a new [`DecimalArray`] from a [`BufferHandle`] with validation.
497    pub fn try_new_handle(
498        values: BufferHandle,
499        values_type: DecimalType,
500        decimal_dtype: DecimalDType,
501        validity: Validity,
502    ) -> VortexResult<Self> {
503        let dtype = DType::Decimal(decimal_dtype, validity.nullability());
504        let len = values.len() / values_type.byte_width();
505        let slots = DecimalData::make_slots(&validity, len);
506        let data = DecimalData::try_new_handle(values, values_type, decimal_dtype)?;
507        Array::try_from_parts(ArrayParts::new(Decimal, dtype, len, data).with_slots(slots))
508    }
509
510    /// Creates a new [`DecimalArray`] without validation from a [`BufferHandle`].
511    ///
512    /// # Safety
513    ///
514    /// See [`DecimalData::new_unchecked_handle`].
515    pub unsafe fn new_unchecked_handle(
516        values: BufferHandle,
517        values_type: DecimalType,
518        decimal_dtype: DecimalDType,
519        validity: Validity,
520    ) -> Self {
521        let dtype = DType::Decimal(decimal_dtype, validity.nullability());
522        let len = values.len() / values_type.byte_width();
523        let slots = DecimalData::make_slots(&validity, len);
524        let data = unsafe { DecimalData::new_unchecked_handle(values, values_type, decimal_dtype) };
525        unsafe {
526            Array::from_parts_unchecked(
527                ArrayParts::new(Decimal, dtype, len, data).with_slots(slots),
528            )
529        }
530    }
531
532    #[allow(
533        clippy::cognitive_complexity,
534        reason = "patching depends on both patch and value physical types"
535    )]
536    pub fn patch(self, patches: &Patches, ctx: &mut ExecutionCtx) -> VortexResult<Self> {
537        let offset = patches.offset();
538        let dtype = self.dtype().clone();
539        let len = self.len();
540        let patch_indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
541        let patch_values = patches.values().clone().execute::<DecimalArray>(ctx)?;
542
543        let patch_validity = patch_values.validity()?;
544        let patched_validity = self.validity()?.patch(
545            self.len(),
546            offset,
547            &patch_indices.clone().into_array(),
548            &patch_validity,
549            ctx,
550        )?;
551        assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
552
553        let data = self.into_data();
554        let data = match_each_integer_ptype!(patch_indices.ptype(), |I| {
555            let patch_indices = patch_indices.as_slice::<I>();
556            match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
557                let patch_values = patch_values.buffer::<PatchDVT>();
558                match_each_decimal_value_type!(data.values_type(), |ValuesDVT| {
559                    let buffer = data.buffer::<ValuesDVT>().into_mut();
560                    patch_typed(
561                        buffer,
562                        data.decimal_dtype(),
563                        patch_indices,
564                        offset,
565                        patch_values,
566                    )
567                })
568            })
569        });
570        let slots = DecimalData::make_slots(&patched_validity, len);
571        Ok(unsafe {
572            Array::from_parts_unchecked(
573                ArrayParts::new(Decimal, dtype, len, data).with_slots(slots),
574            )
575        })
576    }
577}
578
579fn patch_typed<I, ValuesDVT, PatchDVT>(
580    mut buffer: BufferMut<ValuesDVT>,
581    decimal_dtype: DecimalDType,
582    patch_indices: &[I],
583    patch_indices_offset: usize,
584    patch_values: Buffer<PatchDVT>,
585) -> DecimalData
586where
587    I: IntegerPType,
588    PatchDVT: NativeDecimalType,
589    ValuesDVT: NativeDecimalType,
590{
591    if !ValuesDVT::DECIMAL_TYPE.is_compatible_decimal_value_type(decimal_dtype) {
592        vortex_panic!(
593            "patch_typed: {:?} cannot represent every value in {}.",
594            ValuesDVT::DECIMAL_TYPE,
595            decimal_dtype
596        )
597    }
598
599    for (idx, value) in patch_indices.iter().zip_eq(patch_values.into_iter()) {
600        buffer[idx.as_() - patch_indices_offset] = <ValuesDVT as BigCast>::from(value).vortex_expect(
601            "values of a given DecimalDType are representable in all compatible NativeDecimalType",
602        );
603    }
604
605    DecimalData::new(buffer.freeze(), decimal_dtype)
606}