vortex_array/arrays/decimal/
array.rs1use itertools::Itertools;
5use vortex_buffer::BitBufferMut;
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_buffer::ByteBuffer;
9use vortex_dtype::BigCast;
10use vortex_dtype::DType;
11use vortex_dtype::DecimalDType;
12use vortex_dtype::DecimalType;
13use vortex_dtype::IntegerPType;
14use vortex_dtype::NativeDecimalType;
15use vortex_dtype::match_each_decimal_value_type;
16use vortex_dtype::match_each_integer_ptype;
17use vortex_error::VortexExpect;
18use vortex_error::VortexResult;
19use vortex_error::vortex_ensure;
20use vortex_error::vortex_panic;
21
22use crate::ToCanonical;
23use crate::patches::Patches;
24use crate::stats::ArrayStats;
25use crate::validity::Validity;
26use crate::vtable::ValidityHelper;
27
/// An array of decimal values that share one logical precision and scale.
///
/// The values live in a single untyped byte buffer. The physical element type of that buffer is
/// recorded separately in `values_type`, so the same logical [`DecimalDType`] may be backed by
/// different storage widths.
#[derive(Clone, Debug)]
pub struct DecimalArray {
    /// Logical type. The constructors in this file always set `DType::Decimal`, taking the
    /// nullability from the validity.
    pub(super) dtype: DType,
    /// Raw bytes of the values, interpreted as a buffer of `values_type` elements.
    pub(super) values: ByteBuffer,
    /// Physical (storage) type of every element in `values`.
    pub(super) values_type: DecimalType,
    /// Per-element validity (null-ness) of the array.
    pub(super) validity: Validity,
    /// Statistics attached to this array; constructors start it at `Default`.
    pub(super) stats_set: ArrayStats,
}
93
94impl DecimalArray {
95 pub fn new<T: NativeDecimalType>(
102 buffer: Buffer<T>,
103 decimal_dtype: DecimalDType,
104 validity: Validity,
105 ) -> Self {
106 Self::try_new(buffer, decimal_dtype, validity)
107 .vortex_expect("DecimalArray construction failed")
108 }
109
110 pub fn try_new<T: NativeDecimalType>(
119 buffer: Buffer<T>,
120 decimal_dtype: DecimalDType,
121 validity: Validity,
122 ) -> VortexResult<Self> {
123 Self::validate(&buffer, &validity)?;
124
125 Ok(unsafe { Self::new_unchecked(buffer, decimal_dtype, validity) })
127 }
128
129 pub unsafe fn new_unchecked<T: NativeDecimalType>(
143 buffer: Buffer<T>,
144 decimal_dtype: DecimalDType,
145 validity: Validity,
146 ) -> Self {
147 #[cfg(debug_assertions)]
148 Self::validate(&buffer, &validity)
149 .vortex_expect("[Debug Assertion]: Invalid `DecimalArray` parameters");
150
151 Self {
152 values: buffer.into_byte_buffer(),
153 values_type: T::DECIMAL_TYPE,
154 dtype: DType::Decimal(decimal_dtype, validity.nullability()),
155 validity,
156 stats_set: Default::default(),
157 }
158 }
159
160 pub fn validate<T: NativeDecimalType>(
164 buffer: &Buffer<T>,
165 validity: &Validity,
166 ) -> VortexResult<()> {
167 if let Some(len) = validity.maybe_len() {
168 vortex_ensure!(
169 buffer.len() == len,
170 "Buffer and validity length mismatch: buffer={}, validity={}",
171 buffer.len(),
172 len,
173 );
174 }
175
176 Ok(())
177 }
178
179 pub fn byte_buffer(&self) -> ByteBuffer {
181 self.values.clone()
182 }
183
184 pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
185 if self.values_type != T::DECIMAL_TYPE {
186 vortex_panic!(
187 "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
188 T::DECIMAL_TYPE,
189 self.values_type,
190 );
191 }
192 Buffer::<T>::from_byte_buffer(self.values.clone())
193 }
194
195 pub fn decimal_dtype(&self) -> DecimalDType {
197 if let DType::Decimal(decimal_dtype, _) = self.dtype {
198 decimal_dtype
199 } else {
200 vortex_panic!("Expected Decimal dtype, got {:?}", self.dtype)
201 }
202 }
203
204 pub fn values_type(&self) -> DecimalType {
205 self.values_type
206 }
207
208 pub fn precision(&self) -> u8 {
209 self.decimal_dtype().precision()
210 }
211
212 pub fn scale(&self) -> i8 {
213 self.decimal_dtype().scale()
214 }
215
216 pub fn from_iter<T: NativeDecimalType, I: IntoIterator<Item = T>>(
217 iter: I,
218 decimal_dtype: DecimalDType,
219 ) -> Self {
220 let iter = iter.into_iter();
221
222 Self::new(
223 BufferMut::from_iter(iter).freeze(),
224 decimal_dtype,
225 Validity::NonNullable,
226 )
227 }
228
229 pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
230 iter: I,
231 decimal_dtype: DecimalDType,
232 ) -> Self {
233 let iter = iter.into_iter();
234 let mut values = BufferMut::with_capacity(iter.size_hint().0);
235 let mut validity = BitBufferMut::with_capacity(values.capacity());
236
237 for i in iter {
238 match i {
239 None => {
240 validity.append(false);
241 values.push(T::default());
242 }
243 Some(e) => {
244 validity.append(true);
245 values.push(e);
246 }
247 }
248 }
249 Self::new(
250 values.freeze(),
251 decimal_dtype,
252 Validity::from(validity.freeze()),
253 )
254 }
255
256 #[expect(
257 clippy::cognitive_complexity,
258 reason = "complexity from nested match_each_* macros"
259 )]
260 pub fn patch(self, patches: &Patches) -> Self {
261 let offset = patches.offset();
262 let patch_indices = patches.indices().to_primitive();
263 let patch_values = patches.values().to_decimal();
264
265 let patched_validity = self.validity().clone().patch(
266 self.len(),
267 offset,
268 patch_indices.as_ref(),
269 patch_values.validity(),
270 );
271 assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
272
273 match_each_integer_ptype!(patch_indices.ptype(), |I| {
274 let patch_indices = patch_indices.as_slice::<I>();
275 match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
276 let patch_values = patch_values.buffer::<PatchDVT>();
277 match_each_decimal_value_type!(self.values_type(), |ValuesDVT| {
278 let buffer = self.buffer::<ValuesDVT>().into_mut();
279 patch_typed(
280 buffer,
281 self.decimal_dtype(),
282 patch_indices,
283 offset,
284 patch_values,
285 patched_validity,
286 )
287 })
288 })
289 })
290 }
291}
292
293fn patch_typed<I, ValuesDVT, PatchDVT>(
294 mut buffer: BufferMut<ValuesDVT>,
295 decimal_dtype: DecimalDType,
296 patch_indices: &[I],
297 patch_indices_offset: usize,
298 patch_values: Buffer<PatchDVT>,
299 patched_validity: Validity,
300) -> DecimalArray
301where
302 I: IntegerPType,
303 PatchDVT: NativeDecimalType,
304 ValuesDVT: NativeDecimalType,
305{
306 if !ValuesDVT::DECIMAL_TYPE.is_compatible_decimal_value_type(decimal_dtype) {
307 vortex_panic!(
308 "patch_typed: {:?} cannot represent every value in {}.",
309 ValuesDVT::DECIMAL_TYPE,
310 decimal_dtype
311 )
312 }
313
314 for (idx, value) in patch_indices.iter().zip_eq(patch_values.into_iter()) {
315 buffer[idx.as_() - patch_indices_offset] = <ValuesDVT as BigCast>::from(value).vortex_expect(
316 "values of a given DecimalDType are representable in all compatible NativeDecimalType",
317 );
318 }
319
320 DecimalArray::new(buffer.freeze(), decimal_dtype, patched_validity)
321}