vortex_array/arrays/decimal/
patch.rs1use arrow_buffer::ArrowNativeType;
5use itertools::Itertools as _;
6use vortex_buffer::{Buffer, BufferMut};
7use vortex_dtype::{DecimalDType, NativePType, match_each_integer_ptype};
8use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
9use vortex_scalar::{BigCast, NativeDecimalType, match_each_decimal_value_type};
10
11use super::{DecimalArray, compatible_storage_type};
12use crate::ToCanonical as _;
13use crate::patches::Patches;
14use crate::validity::Validity;
15use crate::vtable::ValidityHelper;
16
17impl DecimalArray {
18 #[allow(clippy::cognitive_complexity)]
19 pub fn patch(self, patches: &Patches) -> VortexResult<Self> {
20 let offset = patches.offset();
21 let patch_indices = patches.indices().to_primitive()?;
22 let patch_values = patches.values().to_decimal()?;
23
24 let patched_validity = self.validity().clone().patch(
25 self.len(),
26 offset,
27 patch_indices.as_ref(),
28 patch_values.validity(),
29 )?;
30 assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
31
32 match_each_integer_ptype!(patch_indices.ptype(), |I| {
33 let patch_indices = patch_indices.as_slice::<I>();
34 match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
35 let patch_values = patch_values.buffer::<PatchDVT>();
36 match_each_decimal_value_type!(self.values_type(), |ValuesDVT| {
37 let buffer = self.buffer::<ValuesDVT>().into_mut();
38 patch_typed(
39 buffer,
40 self.decimal_dtype(),
41 patch_indices,
42 offset,
43 patch_values,
44 patched_validity,
45 )
46 })
47 })
48 })
49 }
50}
51
52fn patch_typed<I, ValuesDVT, PatchDVT>(
53 mut buffer: BufferMut<ValuesDVT>,
54 decimal_dtype: DecimalDType,
55 patch_indices: &[I],
56 patch_indices_offset: usize,
57 patch_values: Buffer<PatchDVT>,
58 patched_validity: Validity,
59) -> VortexResult<DecimalArray>
60where
61 I: NativePType + ArrowNativeType,
62 PatchDVT: NativeDecimalType,
63 ValuesDVT: NativeDecimalType,
64{
65 if !compatible_storage_type(ValuesDVT::VALUES_TYPE, decimal_dtype) {
66 vortex_bail!(
67 "patch_typed: {:?} cannot represent every value in {}.",
68 ValuesDVT::VALUES_TYPE,
69 decimal_dtype
70 )
71 }
72
73 for (idx, value) in patch_indices.iter().zip_eq(patch_values.into_iter()) {
74 buffer[idx.as_usize() - patch_indices_offset] = <ValuesDVT as BigCast>::from(value).vortex_expect(
75 "values of a given DecimalDType are representable in all compatible NativeDecimalType",
76 );
77 }
78
79 Ok(DecimalArray::new(
80 buffer.freeze(),
81 decimal_dtype,
82 patched_validity,
83 ))
84}