vortex_array/builders/
decimal.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5
6use vortex_buffer::{Buffer, BufferMut};
7use vortex_dtype::{DType, DecimalDType, Nullability};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_panic};
9use vortex_mask::Mask;
10use vortex_scalar::{BigCast, NativeDecimalType, i256, match_each_decimal_value_type};
11
12use crate::arrays::{BoolArray, DecimalArray};
13use crate::builders::ArrayBuilder;
14use crate::builders::lazy_validity_builder::LazyNullBufferBuilder;
15use crate::validity::Validity;
16use crate::{Array, ArrayRef, IntoArray, ToCanonical};
17
18/// Wrapper around the typed builder.
19///
20/// We want to be able to downcast a `Box<dyn ArrayBuilder>` to a [`DecimalBuilder`] and we generally
21/// don't have enough type information to get the `T` at the call site, so we instead use this
22/// to hold values and can push values into the correct buffer type generically.
23enum DecimalBuffer {
24    I8(BufferMut<i8>),
25    I16(BufferMut<i16>),
26    I32(BufferMut<i32>),
27    I64(BufferMut<i64>),
28    I128(BufferMut<i128>),
29    I256(BufferMut<i256>),
30}
31
32impl Default for DecimalBuffer {
33    fn default() -> Self {
34        Self::I8(BufferMut::<i8>::empty())
35    }
36}
37
38macro_rules! impl_from_buffer {
39    ($typ:ty, $variant:ident) => {
40        impl From<BufferMut<$typ>> for DecimalBuffer {
41            fn from(buffer: BufferMut<$typ>) -> Self {
42                Self::$variant(buffer)
43            }
44        }
45    };
46}
47impl_from_buffer!(i8, I8);
48impl_from_buffer!(i16, I16);
49impl_from_buffer!(i32, I32);
50impl_from_buffer!(i64, I64);
51impl_from_buffer!(i128, I128);
52impl_from_buffer!(i256, I256);
53
/// Runs `$body` against whichever variant of a [`DecimalBuffer`] is active.
///
/// Inside `$body`, `$tname` is a type alias for the variant's value type and `$buffer` is bound
/// to the variant's inner buffer (by reference or by value, depending on how `$self` is passed).
macro_rules! delegate_fn {
    ($self:expr, | $tname:ident, $buffer:ident | $body:block) => {{
        // Not every caller uses both the type alias and the buffer binding.
        #[allow(unused)]
        match $self {
            DecimalBuffer::I8($buffer) => {
                type $tname = i8;
                $body
            }
            DecimalBuffer::I16($buffer) => {
                type $tname = i16;
                $body
            }
            DecimalBuffer::I32($buffer) => {
                type $tname = i32;
                $body
            }
            DecimalBuffer::I64($buffer) => {
                type $tname = i64;
                $body
            }
            DecimalBuffer::I128($buffer) => {
                type $tname = i128;
                $body
            }
            DecimalBuffer::I256($buffer) => {
                type $tname = i256;
                $body
            }
        }
    }};
}
91
92impl DecimalBuffer {
93    fn push<V: NativeDecimalType>(&mut self, value: V) {
94        delegate_fn!(self, |T, buffer| {
95            buffer.push(<T as BigCast>::from(value).vortex_expect("decimal conversion failure"))
96        });
97    }
98
99    fn push_n<V: NativeDecimalType>(&mut self, value: V, n: usize) {
100        delegate_fn!(self, |T, buffer| {
101            buffer.push_n(
102                <T as BigCast>::from(value).vortex_expect("decimal conversion failure"),
103                n,
104            )
105        });
106    }
107
108    fn reserve(&mut self, additional: usize) {
109        delegate_fn!(self, |T, buffer| { buffer.reserve(additional) })
110    }
111
112    fn capacity(&self) -> usize {
113        delegate_fn!(self, |T, buffer| { buffer.capacity() })
114    }
115
116    fn len(&self) -> usize {
117        delegate_fn!(self, |T, buffer| { buffer.len() })
118    }
119
120    pub fn extend<I, V: NativeDecimalType>(&mut self, iter: I)
121    where
122        I: Iterator<Item = V>,
123    {
124        delegate_fn!(self, |T, buffer| {
125            buffer.extend(
126                iter.map(|x| <T as BigCast>::from(x).vortex_expect("decimal conversion failure")),
127            )
128        })
129    }
130}
131
132/// An [`ArrayBuilder`] for `Decimal` typed arrays.
133///
134/// The output will be a new [`DecimalArray`] holding values of `T`. Any value that is
135/// a valid [decimal type][NativeDecimalType] can be appended to the builder and it will be
136/// immediately coerced into the target type.
137pub struct DecimalBuilder {
138    values: DecimalBuffer,
139    nulls: LazyNullBufferBuilder,
140    dtype: DType,
141}
142
143const DEFAULT_BUILDER_CAPACITY: usize = 1024;
144
145impl DecimalBuilder {
146    pub fn new<T: NativeDecimalType>(precision: u8, scale: i8, nullability: Nullability) -> Self {
147        Self::with_capacity::<T>(
148            DEFAULT_BUILDER_CAPACITY,
149            DecimalDType::new(precision, scale),
150            nullability,
151        )
152    }
153
154    pub fn with_capacity<T: NativeDecimalType>(
155        capacity: usize,
156        decimal: DecimalDType,
157        nullability: Nullability,
158    ) -> Self {
159        Self {
160            values: match_each_decimal_value_type!(T::VALUES_TYPE, |D| {
161                DecimalBuffer::from(BufferMut::<D>::with_capacity(capacity))
162            }),
163            nulls: LazyNullBufferBuilder::new(capacity),
164            dtype: DType::Decimal(decimal, nullability),
165        }
166    }
167}
168
169impl DecimalBuilder {
170    fn extend_with_validity_mask(&mut self, validity_mask: Mask) {
171        self.nulls.append_validity_mask(validity_mask);
172    }
173
174    /// Extend the values buffer from another buffer of type V where V can be coerced
175    /// to the builder type.
176    fn extend_from_buffer<V: NativeDecimalType>(&mut self, values: &Buffer<V>) {
177        self.values.extend(values.iter().copied());
178    }
179}
180
181impl DecimalBuilder {
182    pub fn append_value<V: NativeDecimalType>(&mut self, value: V) {
183        self.values.push(value);
184        self.nulls.append(true);
185    }
186
187    pub fn append_option<V: NativeDecimalType>(&mut self, value: Option<V>) {
188        match value {
189            Some(value) => {
190                self.values.push(value);
191                self.nulls.append(true);
192            }
193            None => self.append_null(),
194        }
195    }
196
197    /// Append a `Mask` to the null buffer.
198    pub fn append_mask(&mut self, mask: Mask) {
199        self.nulls.append_validity_mask(mask);
200    }
201}
202
203impl DecimalBuilder {
204    pub fn finish_into_decimal(&mut self) -> DecimalArray {
205        let nulls = self.nulls.finish();
206
207        if let Some(null_buf) = nulls.as_ref() {
208            assert_eq!(
209                null_buf.len(),
210                self.values.len(),
211                "null buffer length must equal value buffer length"
212            );
213        }
214
215        let validity = match (nulls, self.dtype.nullability()) {
216            (None, Nullability::NonNullable) => Validity::NonNullable,
217            (Some(_), Nullability::NonNullable) => {
218                vortex_panic!("Non-nullable builder has null values")
219            }
220            (None, Nullability::Nullable) => Validity::AllValid,
221            (Some(nulls), Nullability::Nullable) => {
222                if nulls.null_count() == nulls.len() {
223                    Validity::AllInvalid
224                } else {
225                    Validity::Array(BoolArray::from(nulls.into_inner()).into_array())
226                }
227            }
228        };
229
230        let DType::Decimal(decimal_dtype, _) = self.dtype else {
231            vortex_panic!("DecimalBuilder must have Decimal DType");
232        };
233
234        delegate_fn!(std::mem::take(&mut self.values), |T, values| {
235            DecimalArray::new::<T>(values.freeze(), decimal_dtype, validity)
236        })
237    }
238}
239
240impl ArrayBuilder for DecimalBuilder {
241    fn as_any(&self) -> &dyn Any {
242        self
243    }
244
245    fn as_any_mut(&mut self) -> &mut dyn Any {
246        self
247    }
248
249    fn dtype(&self) -> &DType {
250        &self.dtype
251    }
252
253    fn len(&self) -> usize {
254        self.values.len()
255    }
256
257    fn append_zeros(&mut self, n: usize) {
258        self.values.push_n(0, n);
259        self.nulls.append_n_non_nulls(n);
260    }
261
262    fn append_nulls(&mut self, n: usize) {
263        self.values.push_n(0, n);
264        self.nulls.append_n_nulls(n);
265    }
266
267    fn extend_from_array(&mut self, array: &dyn Array) -> VortexResult<()> {
268        let array = array.to_decimal()?;
269
270        let DType::Decimal(decimal_dtype, _) = self.dtype else {
271            vortex_panic!("DecimalBuilder must have Decimal DType");
272        };
273
274        if array.decimal_dtype() != decimal_dtype {
275            vortex_bail!(
276                "Cannot extend from array with different decimal type: {:?} != {:?}",
277                array.decimal_dtype(),
278                decimal_dtype
279            );
280        }
281
282        match_each_decimal_value_type!(array.values_type(), |D| {
283            self.extend_from_buffer(&array.buffer::<D>())
284        });
285
286        self.extend_with_validity_mask(array.validity_mask()?);
287
288        Ok(())
289    }
290
291    fn ensure_capacity(&mut self, capacity: usize) {
292        if capacity > self.values.capacity() {
293            self.values.reserve(capacity - self.values.len());
294            self.nulls.ensure_capacity(capacity);
295        }
296    }
297
298    fn set_validity(&mut self, validity: Mask) {
299        self.nulls = LazyNullBufferBuilder::new(validity.len());
300        self.nulls.append_validity_mask(validity);
301    }
302
303    fn finish(&mut self) -> ArrayRef {
304        self.finish_into_decimal().into_array()
305    }
306}
307
#[cfg(test)]
mod tests {
    use crate::builders::{ArrayBuilder, DecimalBuilder};

    /// An array built with `i8` values can be appended to a builder that stores `i128` values.
    #[test]
    fn test_mixed_extend() {
        let mut narrow = DecimalBuilder::new::<i8>(2, 1, false.into());
        for value in [10, 11, 12] {
            narrow.append_value(value);
        }
        let narrow = narrow.finish();

        let mut wide = DecimalBuilder::new::<i128>(2, 1, false.into());
        wide.extend_from_array(&narrow).unwrap();

        let wide = wide.finish_into_decimal();
        assert_eq!(wide.buffer::<i128>().as_slice(), &[10, 11, 12]);
    }
}
325}