Skip to main content

vortex_array/arrays/primitive/array/
patch.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Range;
5
6use vortex_dtype::IntegerPType;
7use vortex_dtype::NativePType;
8use vortex_dtype::UnsignedPType;
9use vortex_dtype::match_each_integer_ptype;
10use vortex_dtype::match_each_native_ptype;
11use vortex_error::VortexResult;
12
13use crate::Array;
14use crate::arrays::PrimitiveArray;
15use crate::patches::PATCH_CHUNK_SIZE;
16use crate::patches::Patches;
17use crate::validity::Validity;
18use crate::vtable::ValidityHelper;
19
20impl PrimitiveArray {
21    pub fn patch(self, patches: &Patches) -> VortexResult<Self> {
22        let patch_indices = patches.indices().to_canonical()?.into_primitive();
23        let patch_values = patches.values().to_canonical()?.into_primitive();
24
25        let patched_validity = self.validity().clone().patch(
26            self.len(),
27            patches.offset(),
28            patch_indices.as_ref(),
29            patch_values.validity(),
30        )?;
31        Ok(match_each_integer_ptype!(patch_indices.ptype(), |I| {
32            match_each_native_ptype!(self.ptype(), |T| {
33                self.patch_typed::<T, I>(
34                    patch_indices,
35                    patches.offset(),
36                    patch_values,
37                    patched_validity,
38                )
39            })
40        }))
41    }
42
43    fn patch_typed<T, I>(
44        self,
45        patch_indices: PrimitiveArray,
46        patch_indices_offset: usize,
47        patch_values: PrimitiveArray,
48        patched_validity: Validity,
49    ) -> Self
50    where
51        T: NativePType,
52        I: IntegerPType,
53    {
54        let mut own_values = self.into_buffer_mut::<T>();
55
56        let patch_indices = patch_indices.as_slice::<I>();
57        let patch_values = patch_values.as_slice::<T>();
58        for (idx, value) in itertools::zip_eq(patch_indices, patch_values) {
59            own_values[idx.as_() - patch_indices_offset] = *value;
60        }
61        Self::new(own_values, patched_validity)
62    }
63}
64
65/// Computes the index range for a chunk, accounting for slice offset.
66///
67/// # Arguments
68///
69/// * `chunk_idx` - Index of the chunk
70/// * `offset` - Offset from slice
71/// * `array_len` - Length of the sliced array
72#[inline]
73pub fn chunk_range(chunk_idx: usize, offset: usize, array_len: usize) -> Range<usize> {
74    let offset_in_chunk = offset % PATCH_CHUNK_SIZE;
75    let local_start = (chunk_idx * PATCH_CHUNK_SIZE).saturating_sub(offset_in_chunk);
76    let local_end = ((chunk_idx + 1) * PATCH_CHUNK_SIZE)
77        .saturating_sub(offset_in_chunk)
78        .min(array_len);
79    local_start..local_end
80}
81
82/// Patches a chunk of decoded values.
83///
84/// # Arguments
85///
86/// * `decoded_values` - Mutable slice of decoded values to be patched
87/// * `patches_indices` - Indices indicating which positions to patch
88/// * `patches_values` - Values to apply at the patched indices
89/// * `patches_offset` - Absolute position where the slice starts
90/// * `chunk_offsets_slice` - Slice containing offsets for each chunk
91/// * `chunk_idx` - Index of the chunk to patch
92/// * `offset_within_chunk` - Number of patches to skip at the start of the first chunk
93#[inline]
94pub fn patch_chunk<T, I, C>(
95    decoded_values: &mut [T],
96    patches_indices: &[I],
97    patches_values: &[T],
98    patches_offset: usize,
99    chunk_offsets_slice: &[C],
100    chunk_idx: usize,
101    offset_within_chunk: usize,
102) where
103    T: NativePType,
104    I: UnsignedPType,
105    C: UnsignedPType,
106{
107    // Compute base_offset from the first chunk offset.
108    let base_offset: usize = chunk_offsets_slice[0].as_();
109
110    // Use the same logic as patches slice implementation for calculating patch ranges.
111    let patches_start_idx =
112        (chunk_offsets_slice[chunk_idx].as_() - base_offset).saturating_sub(offset_within_chunk);
113    let patches_end_idx = if chunk_idx + 1 < chunk_offsets_slice.len() {
114        chunk_offsets_slice[chunk_idx + 1].as_() - base_offset - offset_within_chunk
115    } else {
116        patches_indices.len()
117    };
118
119    let chunk_start = chunk_range(chunk_idx, patches_offset, /* ignore */ usize::MAX).start;
120
121    for patches_idx in patches_start_idx..patches_end_idx {
122        let chunk_relative_index =
123            (patches_indices[patches_idx].as_() - patches_offset) - chunk_start;
124        decoded_values[chunk_relative_index] = patches_values[patches_idx];
125    }
126}
127
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use super::*;
    use crate::ToCanonical;
    use crate::assert_arrays_eq;
    use crate::validity::Validity;

    /// Slicing an all-valid primitive array yields the expected values and validity.
    #[test]
    fn patch_sliced() {
        let input = PrimitiveArray::new(buffer![2u32; 10], Validity::AllValid);
        let sliced = input.slice(2..8).unwrap();
        assert_arrays_eq!(
            sliced.to_primitive(),
            PrimitiveArray::new(buffer![2u32; 6], Validity::AllValid)
        );
    }

    /// Without a slice offset, chunk ranges are aligned to `PATCH_CHUNK_SIZE` and the end is
    /// clamped to the array length.
    #[test]
    fn chunk_range_unsliced() {
        let n = PATCH_CHUNK_SIZE;
        assert_eq!(chunk_range(0, 0, usize::MAX), 0..n);
        assert_eq!(chunk_range(1, 0, usize::MAX), n..2 * n);
        assert_eq!(chunk_range(1, 0, n + 5), n..n + 5);
    }

    /// With a slice offset, chunk boundaries shift down by the offset within a chunk, and the
    /// first chunk's start saturates at zero.
    #[test]
    fn chunk_range_sliced() {
        let n = PATCH_CHUNK_SIZE;
        assert_eq!(chunk_range(0, 3, usize::MAX), 0..n - 3);
        assert_eq!(chunk_range(1, 3, usize::MAX), n - 3..2 * n - 3);
        // Only the position of the slice start within its chunk matters.
        assert_eq!(chunk_range(1, n + 3, usize::MAX), n - 3..2 * n - 3);
    }

    /// Each chunk receives exactly the patches delimited by its chunk offsets, at indices
    /// relative to the chunk start.
    #[test]
    fn patch_chunk_unsliced() {
        let n = PATCH_CHUNK_SIZE;
        let indices = [2u64, n as u64 + 1, 2 * n as u64 - 1];
        let values = [10u32, 20, 30];
        let chunk_offsets = [0u32, 1];

        let mut chunk0 = vec![0u32; n];
        patch_chunk(&mut chunk0[..], &indices, &values, 0, &chunk_offsets, 0, 0);
        let mut expected0 = vec![0u32; n];
        expected0[2] = 10;
        assert_eq!(chunk0, expected0);

        let mut chunk1 = vec![0u32; n];
        patch_chunk(&mut chunk1[..], &indices, &values, 0, &chunk_offsets, 1, 0);
        let mut expected1 = vec![0u32; n];
        expected1[1] = 20;
        expected1[n - 1] = 30;
        assert_eq!(chunk1, expected1);
    }

    /// Models patches sliced to start at position 4. The original patch at index 1 was dropped,
    /// so `offset_within_chunk` is 1 while the chunk offsets still count it.
    #[test]
    fn patch_chunk_sliced() {
        let n = PATCH_CHUNK_SIZE;
        let indices = [5u64, n as u64 + 2];
        let values = [20u32, 30];
        let chunk_offsets = [0u32, 2];

        // The first chunk starts at the slice start, so it is 4 elements shorter.
        let mut chunk0 = vec![0u32; n - 4];
        patch_chunk(&mut chunk0[..], &indices, &values, 4, &chunk_offsets, 0, 1);
        let mut expected0 = vec![0u32; n - 4];
        expected0[1] = 20;
        assert_eq!(chunk0, expected0);

        let mut chunk1 = vec![0u32; n];
        patch_chunk(&mut chunk1[..], &indices, &values, 4, &chunk_offsets, 1, 1);
        let mut expected1 = vec![0u32; n];
        expected1[2] = 30;
        assert_eq!(chunk1, expected1);
    }
}