vortex_array/arrays/decimal/
patch.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_buffer::ArrowNativeType;
5use itertools::Itertools as _;
6use vortex_buffer::{Buffer, BufferMut};
7use vortex_dtype::{DecimalDType, NativePType, match_each_integer_ptype};
8use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
9use vortex_scalar::{BigCast, NativeDecimalType, match_each_decimal_value_type};
10
11use super::{DecimalArray, compatible_storage_type};
12use crate::ToCanonical as _;
13use crate::patches::Patches;
14use crate::validity::Validity;
15use crate::vtable::ValidityHelper;
16
17impl DecimalArray {
18    #[allow(clippy::cognitive_complexity)]
19    pub fn patch(self, patches: &Patches) -> VortexResult<Self> {
20        let offset = patches.offset();
21        let patch_indices = patches.indices().to_primitive()?;
22        let patch_values = patches.values().to_decimal()?;
23
24        let patched_validity = self.validity().clone().patch(
25            self.len(),
26            offset,
27            patch_indices.as_ref(),
28            patch_values.validity(),
29        )?;
30        assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
31
32        match_each_integer_ptype!(patch_indices.ptype(), |I| {
33            let patch_indices = patch_indices.as_slice::<I>();
34            match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
35                let patch_values = patch_values.buffer::<PatchDVT>();
36                match_each_decimal_value_type!(self.values_type(), |ValuesDVT| {
37                    let buffer = self.buffer::<ValuesDVT>().into_mut();
38                    patch_typed(
39                        buffer,
40                        self.decimal_dtype(),
41                        patch_indices,
42                        offset,
43                        patch_values,
44                        patched_validity,
45                    )
46                })
47            })
48        })
49    }
50}
51
52fn patch_typed<I, ValuesDVT, PatchDVT>(
53    mut buffer: BufferMut<ValuesDVT>,
54    decimal_dtype: DecimalDType,
55    patch_indices: &[I],
56    patch_indices_offset: usize,
57    patch_values: Buffer<PatchDVT>,
58    patched_validity: Validity,
59) -> VortexResult<DecimalArray>
60where
61    I: NativePType + ArrowNativeType,
62    PatchDVT: NativeDecimalType,
63    ValuesDVT: NativeDecimalType,
64{
65    if !compatible_storage_type(ValuesDVT::VALUES_TYPE, decimal_dtype) {
66        vortex_bail!(
67            "patch_typed: {:?} cannot represent every value in {}.",
68            ValuesDVT::VALUES_TYPE,
69            decimal_dtype
70        )
71    }
72
73    for (idx, value) in patch_indices.iter().zip_eq(patch_values.into_iter()) {
74        buffer[idx.as_usize() - patch_indices_offset] = <ValuesDVT as BigCast>::from(value).vortex_expect(
75            "values of a given DecimalDType are representable in all compatible NativeDecimalType",
76        );
77    }
78
79    Ok(DecimalArray::new(
80        buffer.freeze(),
81        decimal_dtype,
82        patched_validity,
83    ))
84}