vortex_array/arrays/decimal/
array.rs1use itertools::Itertools;
5use vortex_buffer::BitBufferMut;
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_buffer::ByteBuffer;
9use vortex_dtype::BigCast;
10use vortex_dtype::DType;
11use vortex_dtype::DecimalDType;
12use vortex_dtype::DecimalType;
13use vortex_dtype::IntegerPType;
14use vortex_dtype::NativeDecimalType;
15use vortex_dtype::match_each_decimal_value_type;
16use vortex_dtype::match_each_integer_ptype;
17use vortex_error::VortexExpect;
18use vortex_error::VortexResult;
19use vortex_error::vortex_ensure;
20use vortex_error::vortex_panic;
21
22use crate::ToCanonical;
23use crate::buffer::BufferHandle;
24use crate::patches::Patches;
25use crate::stats::ArrayStats;
26use crate::validity::Validity;
27use crate::vtable::ValidityHelper;
28
29#[derive(Clone, Debug)]
87pub struct DecimalArray {
88 pub(super) dtype: DType,
89 pub(super) values: BufferHandle,
90 pub(super) values_type: DecimalType,
91 pub(super) validity: Validity,
92 pub(super) stats_set: ArrayStats,
93}
94
95pub struct DecimalArrayParts {
96 pub decimal_dtype: DecimalDType,
97 pub values: BufferHandle,
98 pub values_type: DecimalType,
99 pub validity: Validity,
100}
101
102impl DecimalArray {
103 pub fn new<T: NativeDecimalType>(
110 buffer: Buffer<T>,
111 decimal_dtype: DecimalDType,
112 validity: Validity,
113 ) -> Self {
114 Self::try_new(buffer, decimal_dtype, validity)
115 .vortex_expect("DecimalArray construction failed")
116 }
117
118 pub fn new_handle(
126 values: BufferHandle,
127 values_type: DecimalType,
128 decimal_dtype: DecimalDType,
129 validity: Validity,
130 ) -> Self {
131 Self::try_new_handle(values, values_type, decimal_dtype, validity)
132 .vortex_expect("DecimalArray construction failed")
133 }
134
135 pub fn try_new<T: NativeDecimalType>(
144 buffer: Buffer<T>,
145 decimal_dtype: DecimalDType,
146 validity: Validity,
147 ) -> VortexResult<Self> {
148 let values = BufferHandle::new_host(buffer.into_byte_buffer());
149 let values_type = T::DECIMAL_TYPE;
150
151 Self::try_new_handle(values, values_type, decimal_dtype, validity)
152 }
153
154 pub fn try_new_handle(
162 values: BufferHandle,
163 values_type: DecimalType,
164 decimal_dtype: DecimalDType,
165 validity: Validity,
166 ) -> VortexResult<Self> {
167 Self::validate(&values, values_type, &validity)?;
168
169 Ok(unsafe { Self::new_unchecked_handle(values, values_type, decimal_dtype, validity) })
171 }
172
173 pub unsafe fn new_unchecked<T: NativeDecimalType>(
187 buffer: Buffer<T>,
188 decimal_dtype: DecimalDType,
189 validity: Validity,
190 ) -> Self {
191 unsafe {
193 Self::new_unchecked_handle(
194 BufferHandle::new_host(buffer.into_byte_buffer()),
195 T::DECIMAL_TYPE,
196 decimal_dtype,
197 validity,
198 )
199 }
200 }
201
202 pub unsafe fn new_unchecked_handle(
212 values: BufferHandle,
213 values_type: DecimalType,
214 decimal_dtype: DecimalDType,
215 validity: Validity,
216 ) -> Self {
217 #[cfg(debug_assertions)]
218 {
219 Self::validate(&values, values_type, &validity)
220 .vortex_expect("[Debug Assertion]: Invalid `DecimalArray` parameters");
221 }
222
223 Self {
224 values,
225 values_type,
226 dtype: DType::Decimal(decimal_dtype, validity.nullability()),
227 validity,
228 stats_set: Default::default(),
229 }
230 }
231
232 fn validate(
236 buffer: &BufferHandle,
237 values_type: DecimalType,
238 validity: &Validity,
239 ) -> VortexResult<()> {
240 if let Some(validity_len) = validity.maybe_len() {
241 let expected_len = values_type.byte_width() * validity_len;
242 vortex_ensure!(
243 buffer.len() == expected_len,
244 InvalidArgument: "expected buffer of size {} bytes, was {} bytes",
245 expected_len,
246 buffer.len(),
247 );
248 }
249
250 Ok(())
251 }
252
253 pub unsafe fn new_unchecked_from_byte_buffer(
263 byte_buffer: ByteBuffer,
264 values_type: DecimalType,
265 decimal_dtype: DecimalDType,
266 validity: Validity,
267 ) -> Self {
268 unsafe {
270 Self::new_unchecked_handle(
271 BufferHandle::new_host(byte_buffer),
272 values_type,
273 decimal_dtype,
274 validity,
275 )
276 }
277 }
278
279 pub fn into_parts(self) -> DecimalArrayParts {
280 let decimal_dtype = self.dtype.into_decimal_opt().vortex_expect("cannot fail");
281
282 DecimalArrayParts {
283 decimal_dtype,
284 values: self.values,
285 values_type: self.values_type,
286 validity: self.validity,
287 }
288 }
289
290 pub fn buffer_handle(&self) -> &BufferHandle {
292 &self.values
293 }
294
295 pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
296 if self.values_type != T::DECIMAL_TYPE {
297 vortex_panic!(
298 "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
299 T::DECIMAL_TYPE,
300 self.values_type,
301 );
302 }
303 Buffer::<T>::from_byte_buffer(self.values.as_host().clone())
304 }
305
306 pub fn decimal_dtype(&self) -> DecimalDType {
308 if let DType::Decimal(decimal_dtype, _) = self.dtype {
309 decimal_dtype
310 } else {
311 vortex_panic!("Expected Decimal dtype, got {:?}", self.dtype)
312 }
313 }
314
315 pub fn values_type(&self) -> DecimalType {
317 self.values_type
318 }
319
320 pub fn precision(&self) -> u8 {
321 self.decimal_dtype().precision()
322 }
323
324 pub fn scale(&self) -> i8 {
325 self.decimal_dtype().scale()
326 }
327
328 pub fn from_iter<T: NativeDecimalType, I: IntoIterator<Item = T>>(
329 iter: I,
330 decimal_dtype: DecimalDType,
331 ) -> Self {
332 let iter = iter.into_iter();
333
334 Self::new(
335 BufferMut::from_iter(iter).freeze(),
336 decimal_dtype,
337 Validity::NonNullable,
338 )
339 }
340
341 pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
342 iter: I,
343 decimal_dtype: DecimalDType,
344 ) -> Self {
345 let iter = iter.into_iter();
346 let mut values = BufferMut::with_capacity(iter.size_hint().0);
347 let mut validity = BitBufferMut::with_capacity(values.capacity());
348
349 for i in iter {
350 match i {
351 None => {
352 validity.append(false);
353 values.push(T::default());
354 }
355 Some(e) => {
356 validity.append(true);
357 values.push(e);
358 }
359 }
360 }
361 Self::new(
362 values.freeze(),
363 decimal_dtype,
364 Validity::from(validity.freeze()),
365 )
366 }
367
368 #[expect(
369 clippy::cognitive_complexity,
370 reason = "complexity from nested match_each_* macros"
371 )]
372 pub fn patch(self, patches: &Patches) -> VortexResult<Self> {
373 let offset = patches.offset();
374 let patch_indices = patches.indices().to_primitive();
375 let patch_values = patches.values().to_decimal();
376
377 let patched_validity = self.validity().clone().patch(
378 self.len(),
379 offset,
380 patch_indices.as_ref(),
381 patch_values.validity(),
382 )?;
383 assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
384
385 Ok(match_each_integer_ptype!(patch_indices.ptype(), |I| {
386 let patch_indices = patch_indices.as_slice::<I>();
387 match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
388 let patch_values = patch_values.buffer::<PatchDVT>();
389 match_each_decimal_value_type!(self.values_type(), |ValuesDVT| {
390 let buffer = self.buffer::<ValuesDVT>().into_mut();
391 patch_typed(
392 buffer,
393 self.decimal_dtype(),
394 patch_indices,
395 offset,
396 patch_values,
397 patched_validity,
398 )
399 })
400 })
401 }))
402 }
403}
404
405fn patch_typed<I, ValuesDVT, PatchDVT>(
406 mut buffer: BufferMut<ValuesDVT>,
407 decimal_dtype: DecimalDType,
408 patch_indices: &[I],
409 patch_indices_offset: usize,
410 patch_values: Buffer<PatchDVT>,
411 patched_validity: Validity,
412) -> DecimalArray
413where
414 I: IntegerPType,
415 PatchDVT: NativeDecimalType,
416 ValuesDVT: NativeDecimalType,
417{
418 if !ValuesDVT::DECIMAL_TYPE.is_compatible_decimal_value_type(decimal_dtype) {
419 vortex_panic!(
420 "patch_typed: {:?} cannot represent every value in {}.",
421 ValuesDVT::DECIMAL_TYPE,
422 decimal_dtype
423 )
424 }
425
426 for (idx, value) in patch_indices.iter().zip_eq(patch_values.into_iter()) {
427 buffer[idx.as_() - patch_indices_offset] = <ValuesDVT as BigCast>::from(value).vortex_expect(
428 "values of a given DecimalDType are representable in all compatible NativeDecimalType",
429 );
430 }
431
432 DecimalArray::new(buffer.freeze(), decimal_dtype, patched_validity)
433}