vortex_array/arrays/decimal/
array.rs1use itertools::Itertools;
5use vortex_buffer::BitBufferMut;
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_buffer::ByteBuffer;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13
14use crate::ExecutionCtx;
15use crate::arrays::PrimitiveArray;
16use crate::buffer::BufferHandle;
17use crate::dtype::BigCast;
18use crate::dtype::DType;
19use crate::dtype::DecimalDType;
20use crate::dtype::DecimalType;
21use crate::dtype::IntegerPType;
22use crate::dtype::NativeDecimalType;
23use crate::match_each_decimal_value_type;
24use crate::match_each_integer_ptype;
25use crate::patches::Patches;
26use crate::stats::ArrayStats;
27use crate::validity::Validity;
28use crate::vtable::ValidityHelper;
29
/// An array of decimal values with a shared [`DecimalDType`] (precision and scale).
///
/// Values live in a single untyped [`BufferHandle`] whose elements all have the native width
/// given by `values_type`. Nullability comes from `validity`.
#[derive(Clone, Debug)]
pub struct DecimalArray {
    /// Logical type. The constructors in this file always set it to `DType::Decimal`, with
    /// nullability taken from `validity`.
    pub(super) dtype: DType,
    /// Raw element storage: `len * values_type.byte_width()` bytes.
    pub(super) values: BufferHandle,
    /// Native element type stored in `values`. The constructors in this file do not check it
    /// against the precision of `dtype`.
    pub(super) values_type: DecimalType,
    /// Per-element validity.
    pub(super) validity: Validity,
    /// Statistics storage; the constructors start it as `Default::default()`.
    pub(super) stats_set: ArrayStats,
}
95
/// The owned components of a [`DecimalArray`], as produced by [`DecimalArray::into_parts`].
pub struct DecimalArrayParts {
    /// Precision and scale of the decimal values.
    pub decimal_dtype: DecimalDType,
    /// Raw element storage.
    pub values: BufferHandle,
    /// Native element type stored in `values`.
    pub values_type: DecimalType,
    /// Per-element validity.
    pub validity: Validity,
}
102
103impl DecimalArray {
104 pub fn new<T: NativeDecimalType>(
111 buffer: Buffer<T>,
112 decimal_dtype: DecimalDType,
113 validity: Validity,
114 ) -> Self {
115 Self::try_new(buffer, decimal_dtype, validity)
116 .vortex_expect("DecimalArray construction failed")
117 }
118
119 pub fn new_handle(
127 values: BufferHandle,
128 values_type: DecimalType,
129 decimal_dtype: DecimalDType,
130 validity: Validity,
131 ) -> Self {
132 Self::try_new_handle(values, values_type, decimal_dtype, validity)
133 .vortex_expect("DecimalArray construction failed")
134 }
135
136 pub fn try_new<T: NativeDecimalType>(
145 buffer: Buffer<T>,
146 decimal_dtype: DecimalDType,
147 validity: Validity,
148 ) -> VortexResult<Self> {
149 let values = BufferHandle::new_host(buffer.into_byte_buffer());
150 let values_type = T::DECIMAL_TYPE;
151
152 Self::try_new_handle(values, values_type, decimal_dtype, validity)
153 }
154
155 pub fn try_new_handle(
163 values: BufferHandle,
164 values_type: DecimalType,
165 decimal_dtype: DecimalDType,
166 validity: Validity,
167 ) -> VortexResult<Self> {
168 Self::validate(&values, values_type, &validity)?;
169
170 Ok(unsafe { Self::new_unchecked_handle(values, values_type, decimal_dtype, validity) })
172 }
173
174 pub unsafe fn new_unchecked<T: NativeDecimalType>(
188 buffer: Buffer<T>,
189 decimal_dtype: DecimalDType,
190 validity: Validity,
191 ) -> Self {
192 unsafe {
194 Self::new_unchecked_handle(
195 BufferHandle::new_host(buffer.into_byte_buffer()),
196 T::DECIMAL_TYPE,
197 decimal_dtype,
198 validity,
199 )
200 }
201 }
202
203 pub unsafe fn new_unchecked_handle(
213 values: BufferHandle,
214 values_type: DecimalType,
215 decimal_dtype: DecimalDType,
216 validity: Validity,
217 ) -> Self {
218 #[cfg(debug_assertions)]
219 {
220 Self::validate(&values, values_type, &validity)
221 .vortex_expect("[Debug Assertion]: Invalid `DecimalArray` parameters");
222 }
223
224 Self {
225 values,
226 values_type,
227 dtype: DType::Decimal(decimal_dtype, validity.nullability()),
228 validity,
229 stats_set: Default::default(),
230 }
231 }
232
233 fn validate(
237 buffer: &BufferHandle,
238 values_type: DecimalType,
239 validity: &Validity,
240 ) -> VortexResult<()> {
241 if let Some(validity_len) = validity.maybe_len() {
242 let expected_len = values_type.byte_width() * validity_len;
243 vortex_ensure!(
244 buffer.len() == expected_len,
245 InvalidArgument: "expected buffer of size {} bytes, was {} bytes",
246 expected_len,
247 buffer.len(),
248 );
249 }
250
251 Ok(())
252 }
253
254 pub unsafe fn new_unchecked_from_byte_buffer(
264 byte_buffer: ByteBuffer,
265 values_type: DecimalType,
266 decimal_dtype: DecimalDType,
267 validity: Validity,
268 ) -> Self {
269 unsafe {
271 Self::new_unchecked_handle(
272 BufferHandle::new_host(byte_buffer),
273 values_type,
274 decimal_dtype,
275 validity,
276 )
277 }
278 }
279
280 pub fn into_parts(self) -> DecimalArrayParts {
281 let decimal_dtype = self.dtype.into_decimal_opt().vortex_expect("cannot fail");
282
283 DecimalArrayParts {
284 decimal_dtype,
285 values: self.values,
286 values_type: self.values_type,
287 validity: self.validity,
288 }
289 }
290
291 pub fn buffer_handle(&self) -> &BufferHandle {
293 &self.values
294 }
295
296 pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
297 if self.values_type != T::DECIMAL_TYPE {
298 vortex_panic!(
299 "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
300 T::DECIMAL_TYPE,
301 self.values_type,
302 );
303 }
304 Buffer::<T>::from_byte_buffer(self.values.as_host().clone())
305 }
306
307 pub fn decimal_dtype(&self) -> DecimalDType {
309 if let DType::Decimal(decimal_dtype, _) = self.dtype {
310 decimal_dtype
311 } else {
312 vortex_panic!("Expected Decimal dtype, got {:?}", self.dtype)
313 }
314 }
315
316 pub fn values_type(&self) -> DecimalType {
318 self.values_type
319 }
320
321 pub fn precision(&self) -> u8 {
322 self.decimal_dtype().precision()
323 }
324
325 pub fn scale(&self) -> i8 {
326 self.decimal_dtype().scale()
327 }
328
329 pub fn from_iter<T: NativeDecimalType, I: IntoIterator<Item = T>>(
330 iter: I,
331 decimal_dtype: DecimalDType,
332 ) -> Self {
333 let iter = iter.into_iter();
334
335 Self::new(
336 BufferMut::from_iter(iter).freeze(),
337 decimal_dtype,
338 Validity::NonNullable,
339 )
340 }
341
342 pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
343 iter: I,
344 decimal_dtype: DecimalDType,
345 ) -> Self {
346 let iter = iter.into_iter();
347 let mut values = BufferMut::with_capacity(iter.size_hint().0);
348 let mut validity = BitBufferMut::with_capacity(values.capacity());
349
350 for i in iter {
351 match i {
352 None => {
353 validity.append(false);
354 values.push(T::default());
355 }
356 Some(e) => {
357 validity.append(true);
358 values.push(e);
359 }
360 }
361 }
362 Self::new(
363 values.freeze(),
364 decimal_dtype,
365 Validity::from(validity.freeze()),
366 )
367 }
368
369 #[expect(
370 clippy::cognitive_complexity,
371 reason = "complexity from nested match_each_* macros"
372 )]
373 pub fn patch(self, patches: &Patches, ctx: &mut ExecutionCtx) -> VortexResult<Self> {
374 let offset = patches.offset();
375 let patch_indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
376 let patch_values = patches.values().clone().execute::<DecimalArray>(ctx)?;
377
378 let patched_validity = self.validity().clone().patch(
379 self.len(),
380 offset,
381 &patch_indices.to_array(),
382 patch_values.validity(),
383 ctx,
384 )?;
385 assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
386
387 Ok(match_each_integer_ptype!(patch_indices.ptype(), |I| {
388 let patch_indices = patch_indices.as_slice::<I>();
389 match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
390 let patch_values = patch_values.buffer::<PatchDVT>();
391 match_each_decimal_value_type!(self.values_type(), |ValuesDVT| {
392 let buffer = self.buffer::<ValuesDVT>().into_mut();
393 patch_typed(
394 buffer,
395 self.decimal_dtype(),
396 patch_indices,
397 offset,
398 patch_values,
399 patched_validity,
400 )
401 })
402 })
403 }))
404 }
405}
406
407fn patch_typed<I, ValuesDVT, PatchDVT>(
408 mut buffer: BufferMut<ValuesDVT>,
409 decimal_dtype: DecimalDType,
410 patch_indices: &[I],
411 patch_indices_offset: usize,
412 patch_values: Buffer<PatchDVT>,
413 patched_validity: Validity,
414) -> DecimalArray
415where
416 I: IntegerPType,
417 PatchDVT: NativeDecimalType,
418 ValuesDVT: NativeDecimalType,
419{
420 if !ValuesDVT::DECIMAL_TYPE.is_compatible_decimal_value_type(decimal_dtype) {
421 vortex_panic!(
422 "patch_typed: {:?} cannot represent every value in {}.",
423 ValuesDVT::DECIMAL_TYPE,
424 decimal_dtype
425 )
426 }
427
428 for (idx, value) in patch_indices.iter().zip_eq(patch_values.into_iter()) {
429 buffer[idx.as_() - patch_indices_offset] = <ValuesDVT as BigCast>::from(value).vortex_expect(
430 "values of a given DecimalDType are representable in all compatible NativeDecimalType",
431 );
432 }
433
434 DecimalArray::new(buffer.freeze(), decimal_dtype, patched_validity)
435}