vortex_array/arrays/decimal/
array.rs1use itertools::Itertools;
5use vortex_buffer::BitBufferMut;
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_buffer::ByteBuffer;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13
14use crate::ExecutionCtx;
15use crate::IntoArray;
16use crate::arrays::PrimitiveArray;
17use crate::buffer::BufferHandle;
18use crate::dtype::BigCast;
19use crate::dtype::DType;
20use crate::dtype::DecimalDType;
21use crate::dtype::DecimalType;
22use crate::dtype::IntegerPType;
23use crate::dtype::NativeDecimalType;
24use crate::match_each_decimal_value_type;
25use crate::match_each_integer_ptype;
26use crate::patches::Patches;
27use crate::stats::ArrayStats;
28use crate::validity::Validity;
29use crate::vtable::ValidityHelper;
30
/// An array of decimal values with a fixed precision and scale.
///
/// The values live in a single untyped [`BufferHandle`] whose elements are of `values_type`.
/// The logical precision and scale are carried by the [`DType::Decimal`] stored in `dtype`,
/// and null slots are tracked separately by `validity`.
#[derive(Clone, Debug)]
pub struct DecimalArray {
    /// Logical type of the array. The constructors in this module always build a
    /// `DType::Decimal` whose nullability is taken from `validity`.
    pub(super) dtype: DType,
    /// Raw element storage, interpreted as a sequence of `values_type` elements.
    pub(super) values: BufferHandle,
    /// Physical integer type used to store each element of `values`.
    pub(super) values_type: DecimalType,
    /// Which elements are null.
    pub(super) validity: Validity,
    /// Cached statistics for the array (starts out as `Default::default()` on construction).
    pub(super) stats_set: ArrayStats,
}
96
/// The owned components of a [`DecimalArray`], as returned by `DecimalArray::into_parts`.
pub struct DecimalArrayParts {
    /// Precision and scale of the decimal values.
    pub decimal_dtype: DecimalDType,
    /// Raw element storage, interpreted as a sequence of `values_type` elements.
    pub values: BufferHandle,
    /// Physical integer type used to store each element of `values`.
    pub values_type: DecimalType,
    /// Which elements are null.
    pub validity: Validity,
}
103
104impl DecimalArray {
105 pub fn new<T: NativeDecimalType>(
112 buffer: Buffer<T>,
113 decimal_dtype: DecimalDType,
114 validity: Validity,
115 ) -> Self {
116 Self::try_new(buffer, decimal_dtype, validity)
117 .vortex_expect("DecimalArray construction failed")
118 }
119
120 pub fn new_handle(
128 values: BufferHandle,
129 values_type: DecimalType,
130 decimal_dtype: DecimalDType,
131 validity: Validity,
132 ) -> Self {
133 Self::try_new_handle(values, values_type, decimal_dtype, validity)
134 .vortex_expect("DecimalArray construction failed")
135 }
136
137 pub fn try_new<T: NativeDecimalType>(
146 buffer: Buffer<T>,
147 decimal_dtype: DecimalDType,
148 validity: Validity,
149 ) -> VortexResult<Self> {
150 let values = BufferHandle::new_host(buffer.into_byte_buffer());
151 let values_type = T::DECIMAL_TYPE;
152
153 Self::try_new_handle(values, values_type, decimal_dtype, validity)
154 }
155
156 pub fn try_new_handle(
164 values: BufferHandle,
165 values_type: DecimalType,
166 decimal_dtype: DecimalDType,
167 validity: Validity,
168 ) -> VortexResult<Self> {
169 Self::validate(&values, values_type, &validity)?;
170
171 Ok(unsafe { Self::new_unchecked_handle(values, values_type, decimal_dtype, validity) })
173 }
174
175 pub unsafe fn new_unchecked<T: NativeDecimalType>(
189 buffer: Buffer<T>,
190 decimal_dtype: DecimalDType,
191 validity: Validity,
192 ) -> Self {
193 unsafe {
195 Self::new_unchecked_handle(
196 BufferHandle::new_host(buffer.into_byte_buffer()),
197 T::DECIMAL_TYPE,
198 decimal_dtype,
199 validity,
200 )
201 }
202 }
203
204 pub unsafe fn new_unchecked_handle(
214 values: BufferHandle,
215 values_type: DecimalType,
216 decimal_dtype: DecimalDType,
217 validity: Validity,
218 ) -> Self {
219 #[cfg(debug_assertions)]
220 {
221 Self::validate(&values, values_type, &validity)
222 .vortex_expect("[Debug Assertion]: Invalid `DecimalArray` parameters");
223 }
224
225 Self {
226 values,
227 values_type,
228 dtype: DType::Decimal(decimal_dtype, validity.nullability()),
229 validity,
230 stats_set: Default::default(),
231 }
232 }
233
234 fn validate(
238 buffer: &BufferHandle,
239 values_type: DecimalType,
240 validity: &Validity,
241 ) -> VortexResult<()> {
242 if let Some(validity_len) = validity.maybe_len() {
243 let expected_len = values_type.byte_width() * validity_len;
244 vortex_ensure!(
245 buffer.len() == expected_len,
246 InvalidArgument: "expected buffer of size {} bytes, was {} bytes",
247 expected_len,
248 buffer.len(),
249 );
250 }
251
252 Ok(())
253 }
254
255 pub unsafe fn new_unchecked_from_byte_buffer(
265 byte_buffer: ByteBuffer,
266 values_type: DecimalType,
267 decimal_dtype: DecimalDType,
268 validity: Validity,
269 ) -> Self {
270 unsafe {
272 Self::new_unchecked_handle(
273 BufferHandle::new_host(byte_buffer),
274 values_type,
275 decimal_dtype,
276 validity,
277 )
278 }
279 }
280
281 pub fn into_parts(self) -> DecimalArrayParts {
282 let decimal_dtype = self.dtype.into_decimal_opt().vortex_expect("cannot fail");
283
284 DecimalArrayParts {
285 decimal_dtype,
286 values: self.values,
287 values_type: self.values_type,
288 validity: self.validity,
289 }
290 }
291
292 pub fn buffer_handle(&self) -> &BufferHandle {
294 &self.values
295 }
296
297 pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
298 if self.values_type != T::DECIMAL_TYPE {
299 vortex_panic!(
300 "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
301 T::DECIMAL_TYPE,
302 self.values_type,
303 );
304 }
305 Buffer::<T>::from_byte_buffer(self.values.as_host().clone())
306 }
307
308 pub fn decimal_dtype(&self) -> DecimalDType {
310 if let DType::Decimal(decimal_dtype, _) = self.dtype {
311 decimal_dtype
312 } else {
313 vortex_panic!("Expected Decimal dtype, got {:?}", self.dtype)
314 }
315 }
316
317 pub fn values_type(&self) -> DecimalType {
319 self.values_type
320 }
321
322 pub fn precision(&self) -> u8 {
323 self.decimal_dtype().precision()
324 }
325
326 pub fn scale(&self) -> i8 {
327 self.decimal_dtype().scale()
328 }
329
330 pub fn from_iter<T: NativeDecimalType, I: IntoIterator<Item = T>>(
331 iter: I,
332 decimal_dtype: DecimalDType,
333 ) -> Self {
334 let iter = iter.into_iter();
335
336 Self::new(
337 BufferMut::from_iter(iter).freeze(),
338 decimal_dtype,
339 Validity::NonNullable,
340 )
341 }
342
343 pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
344 iter: I,
345 decimal_dtype: DecimalDType,
346 ) -> Self {
347 let iter = iter.into_iter();
348 let mut values = BufferMut::with_capacity(iter.size_hint().0);
349 let mut validity = BitBufferMut::with_capacity(values.capacity());
350
351 for i in iter {
352 match i {
353 None => {
354 validity.append(false);
355 values.push(T::default());
356 }
357 Some(e) => {
358 validity.append(true);
359 values.push(e);
360 }
361 }
362 }
363 Self::new(
364 values.freeze(),
365 decimal_dtype,
366 Validity::from(validity.freeze()),
367 )
368 }
369
370 #[expect(
371 clippy::cognitive_complexity,
372 reason = "complexity from nested match_each_* macros"
373 )]
374 pub fn patch(self, patches: &Patches, ctx: &mut ExecutionCtx) -> VortexResult<Self> {
375 let offset = patches.offset();
376 let patch_indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
377 let patch_values = patches.values().clone().execute::<DecimalArray>(ctx)?;
378
379 let patched_validity = self.validity().clone().patch(
380 self.len(),
381 offset,
382 &patch_indices.clone().into_array(),
383 patch_values.validity(),
384 ctx,
385 )?;
386 assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
387
388 Ok(match_each_integer_ptype!(patch_indices.ptype(), |I| {
389 let patch_indices = patch_indices.as_slice::<I>();
390 match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
391 let patch_values = patch_values.buffer::<PatchDVT>();
392 match_each_decimal_value_type!(self.values_type(), |ValuesDVT| {
393 let buffer = self.buffer::<ValuesDVT>().into_mut();
394 patch_typed(
395 buffer,
396 self.decimal_dtype(),
397 patch_indices,
398 offset,
399 patch_values,
400 patched_validity,
401 )
402 })
403 })
404 }))
405 }
406}
407
408fn patch_typed<I, ValuesDVT, PatchDVT>(
409 mut buffer: BufferMut<ValuesDVT>,
410 decimal_dtype: DecimalDType,
411 patch_indices: &[I],
412 patch_indices_offset: usize,
413 patch_values: Buffer<PatchDVT>,
414 patched_validity: Validity,
415) -> DecimalArray
416where
417 I: IntegerPType,
418 PatchDVT: NativeDecimalType,
419 ValuesDVT: NativeDecimalType,
420{
421 if !ValuesDVT::DECIMAL_TYPE.is_compatible_decimal_value_type(decimal_dtype) {
422 vortex_panic!(
423 "patch_typed: {:?} cannot represent every value in {}.",
424 ValuesDVT::DECIMAL_TYPE,
425 decimal_dtype
426 )
427 }
428
429 for (idx, value) in patch_indices.iter().zip_eq(patch_values.into_iter()) {
430 buffer[idx.as_() - patch_indices_offset] = <ValuesDVT as BigCast>::from(value).vortex_expect(
431 "values of a given DecimalDType are representable in all compatible NativeDecimalType",
432 );
433 }
434
435 DecimalArray::new(buffer.freeze(), decimal_dtype, patched_validity)
436}