Skip to main content

vortex_array/arrays/primitive/array/
patch.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Range;
5
6use vortex_error::VortexResult;
7
8use crate::ExecutionCtx;
9use crate::IntoArray;
10use crate::arrays::PrimitiveArray;
11use crate::dtype::IntegerPType;
12use crate::dtype::NativePType;
13use crate::dtype::UnsignedPType;
14use crate::match_each_integer_ptype;
15use crate::match_each_native_ptype;
16use crate::patches::PATCH_CHUNK_SIZE;
17use crate::patches::Patches;
18use crate::validity::Validity;
19
20impl PrimitiveArray {
21    pub fn patch(self, patches: &Patches, ctx: &mut ExecutionCtx) -> VortexResult<Self> {
22        let patch_indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
23        let patch_values = patches.values().clone().execute::<PrimitiveArray>(ctx)?;
24
25        let patch_validity = patch_values.validity()?;
26        let patched_validity = self.validity()?.patch(
27            self.len(),
28            patches.offset(),
29            &patch_indices.clone().into_array(),
30            &patch_validity,
31            ctx,
32        )?;
33        Ok(match_each_integer_ptype!(patch_indices.ptype(), |I| {
34            match_each_native_ptype!(self.ptype(), |T| {
35                self.patch_typed::<T, I>(
36                    patch_indices,
37                    patches.offset(),
38                    patch_values,
39                    patched_validity,
40                )
41            })
42        }))
43    }
44
45    fn patch_typed<T, I>(
46        self,
47        patch_indices: PrimitiveArray,
48        patch_indices_offset: usize,
49        patch_values: PrimitiveArray,
50        patched_validity: Validity,
51    ) -> Self
52    where
53        T: NativePType,
54        I: IntegerPType,
55    {
56        let mut own_values = self.into_buffer_mut::<T>();
57
58        let patch_indices = patch_indices.as_slice::<I>();
59        let patch_values = patch_values.as_slice::<T>();
60        for (idx, value) in itertools::zip_eq(patch_indices, patch_values) {
61            own_values[idx.as_() - patch_indices_offset] = *value;
62        }
63        Self::new(own_values, patched_validity)
64    }
65}
66
67/// Computes the index range for a chunk, accounting for slice offset.
68///
69/// # Arguments
70///
71/// * `chunk_idx` - Index of the chunk
72/// * `offset` - Offset from slice
73/// * `array_len` - Length of the sliced array
74#[inline]
75pub fn chunk_range(chunk_idx: usize, offset: usize, array_len: usize) -> Range<usize> {
76    let offset_in_chunk = offset % PATCH_CHUNK_SIZE;
77    let local_start = (chunk_idx * PATCH_CHUNK_SIZE).saturating_sub(offset_in_chunk);
78    let local_end = ((chunk_idx + 1) * PATCH_CHUNK_SIZE)
79        .saturating_sub(offset_in_chunk)
80        .min(array_len);
81    local_start..local_end
82}
83
84/// Patches a chunk of decoded values.
85///
86/// # Arguments
87///
88/// * `decoded_values` - Mutable slice of decoded values to be patched
89/// * `patches_indices` - Indices indicating which positions to patch
90/// * `patches_values` - Values to apply at the patched indices
91/// * `patches_offset` - Absolute position where the slice starts
92/// * `chunk_offsets_slice` - Slice containing offsets for each chunk
93/// * `chunk_idx` - Index of the chunk to patch
94/// * `offset_within_chunk` - Number of patches to skip at the start of the first chunk
95pub fn patch_chunk<T, I, C>(
96    decoded_values: &mut [T],
97    patches_indices: &[I],
98    patches_values: &[T],
99    patches_offset: usize,
100    chunk_offsets_slice: &[C],
101    chunk_idx: usize,
102    offset_within_chunk: usize,
103) where
104    T: NativePType,
105    I: UnsignedPType,
106    C: UnsignedPType,
107{
108    // Compute base_offset from the first chunk offset.
109    let base_offset: usize = chunk_offsets_slice[0].as_();
110
111    // Use the same logic as patches slice implementation for calculating patch ranges.
112    let patches_start_idx =
113        (chunk_offsets_slice[chunk_idx].as_() - base_offset).saturating_sub(offset_within_chunk);
114    // Clamp: chunk_offsets are sliced at chunk granularity but patches at element
115    // granularity, so the next chunk offset may exceed the actual patches length.
116    let patches_end_idx = if chunk_idx + 1 < chunk_offsets_slice.len() {
117        (chunk_offsets_slice[chunk_idx + 1].as_() - base_offset)
118            .saturating_sub(offset_within_chunk)
119            .min(patches_indices.len())
120    } else {
121        patches_indices.len()
122    };
123
124    let chunk_start = chunk_range(chunk_idx, patches_offset, /* ignore */ usize::MAX).start;
125
126    for patches_idx in patches_start_idx..patches_end_idx {
127        let chunk_relative_index =
128            (patches_indices[patches_idx].as_() - patches_offset) - chunk_start;
129        decoded_values[chunk_relative_index] = patches_values[patches_idx];
130    }
131}
132
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    // `super::*` also brings in `Validity`, `PATCH_CHUNK_SIZE`, `patch_chunk` and
    // `PrimitiveArray` through the parent module's imports.
    use super::*;
    use crate::ToCanonical;
    use crate::assert_arrays_eq;

    /// Regression: patch_chunk must not OOB when chunk_offsets (chunk granularity)
    /// reference more patches than patches_indices (element granularity) contains.
    #[test]
    fn patch_chunk_no_oob_on_mid_chunk_slice() {
        let mut decoded_values = vec![0.0f64; PATCH_CHUNK_SIZE];
        // 10 patches, but chunk_offsets imply up to 15 relative to the base offset.
        let patches_indices: Vec<u64> = (0..10)
            .map(|i| (PATCH_CHUNK_SIZE as u64) + i * 10)
            .collect();
        let patches_values: Vec<f64> = (0..10).map(|i| (i + 1) as f64 * 100.0).collect();
        // chunk_offsets [5, 12, 20]: for chunk_idx=1 with offset_within_chunk=3,
        // start = (12-5)-3 = 4 and the unclamped end = (20-5)-3 = 12, which exceeds
        // patches len of 10 and must be clamped to 10.
        let chunk_offsets: Vec<u32> = vec![5, 12, 20];

        patch_chunk(
            &mut decoded_values,
            &patches_indices,
            &patches_values,
            0,
            &chunk_offsets,
            1,
            3,
        );

        // With patches_offset = 0, chunk 1 starts at PATCH_CHUNK_SIZE, so each patch
        // lands at `index - PATCH_CHUNK_SIZE` within the chunk.
        let position = |index: u64| usize::try_from(index).unwrap() - PATCH_CHUNK_SIZE;

        // Patches 4..10 are in this chunk's range and must all be applied.
        for (&index, &value) in patches_indices.iter().zip(&patches_values).skip(4) {
            assert_eq!(decoded_values[position(index)], value);
        }
        // Patches 0..4 precede the range and must not have been written.
        for &index in &patches_indices[..4] {
            assert_eq!(decoded_values[position(index)], 0.0);
        }
        // Exactly the six in-range patches touched the buffer.
        let written = decoded_values.iter().filter(|v| **v != 0.0).count();
        assert_eq!(written, 6);
    }

    /// Slicing a primitive array keeps the selected values and their validity.
    /// NOTE: despite its name, this test does not call `PrimitiveArray::patch`.
    #[test]
    fn patch_sliced() {
        let input = PrimitiveArray::new(buffer![2u32; 10], Validity::AllValid);
        let sliced = input.slice(2..8).unwrap();
        assert_arrays_eq!(
            sliced.to_primitive(),
            PrimitiveArray::new(buffer![2u32; 6], Validity::AllValid)
        );
    }
}