Skip to main content

vortex_array/arrays/patched/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::ops::Range;
7
8use vortex_buffer::Buffer;
9use vortex_buffer::BufferMut;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_err;
13
14use crate::ArrayRef;
15use crate::ArraySlots;
16use crate::Canonical;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::LEGACY_SESSION;
20use crate::VortexSessionExecute;
21use crate::array::Array;
22use crate::array::ArrayParts;
23use crate::array::TypedArrayRef;
24use crate::array_slots;
25use crate::arrays::Patched;
26use crate::arrays::PrimitiveArray;
27use crate::arrays::patched::TransposedPatches;
28use crate::arrays::patched::patch_lanes;
29use crate::buffer::BufferHandle;
30use crate::dtype::DType;
31use crate::dtype::IntegerPType;
32use crate::dtype::NativePType;
33use crate::dtype::PType;
34use crate::match_each_native_ptype;
35use crate::match_each_unsigned_integer_ptype;
36use crate::patches::Patches;
37use crate::validity::Validity;
38
/// Per-array metadata of a [`Patched`] array: how its patches are split into lanes, and where
/// the logical array begins inside its first 1024-element chunk.
#[derive(Debug, Clone)]
pub struct PatchedData {
    /// Number of lanes the patch indices and values have been split into. Each of the `n_chunks`
    /// chunks of 1024 values is split into `n_lanes` lanes horizontally, each lane having
    /// 1024 / n_lanes values that might be patched. When built by `transpose`, position `p`
    /// within a chunk belongs to lane `p % n_lanes`.
    pub(super) n_lanes: usize,

    /// The offset into the first chunk at which the in-bounds (logical) values begin.
    ///
    /// The patch indices of the first chunk that are less than `offset` should be skipped, and
    /// `offset` should be subtracted from the remaining patch indices to get their final
    /// position in the executed array.
    pub(super) offset: usize,
}
53
/// Child arrays (slots) of a [`Patched`] array.
///
/// NOTE(review): the `array_slots` macro presumably maps slots positionally in field
/// declaration order — confirm before reordering these fields.
#[array_slots(Patched)]
pub struct PatchedSlots {
    /// The inner array containing the base unpatched values.
    pub inner: ArrayRef,
    /// The lane offsets array for locating patches within lanes: entry
    /// `chunk * n_lanes + lane` is the start of that lane's patches in `patch_indices` /
    /// `patch_values`, and the following entry is where they stop.
    pub lane_offsets: ArrayRef,
    /// The indices of patched (exception) values. When built by `from_array_and_patches`
    /// these are `u16` positions within each 1024-element chunk.
    pub patch_indices: ArrayRef,
    /// The patched (exception) values at the corresponding indices.
    pub patch_values: ArrayRef,
}
65
66impl Display for PatchedData {
67    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
68        write!(f, "n_lanes: {}, offset: {}", self.n_lanes, self.offset)
69    }
70}
71
72impl PatchedData {
73    pub(crate) fn validate(
74        &self,
75        dtype: &DType,
76        len: usize,
77        slots: &PatchedSlotsView,
78    ) -> VortexResult<()> {
79        vortex_ensure!(
80            slots.inner.dtype() == dtype,
81            "PatchedArray base dtype {} does not match outer dtype {}",
82            slots.inner.dtype(),
83            dtype
84        );
85        vortex_ensure!(
86            slots.inner.len() == len,
87            "PatchedArray base len {} does not match outer len {}",
88            slots.inner.len(),
89            len
90        );
91        vortex_ensure!(
92            slots.patch_indices.len() == slots.patch_values.len(),
93            "PatchedArray patch indices len {} does not match patch values len {}",
94            slots.patch_indices.len(),
95            slots.patch_values.len()
96        );
97        Ok(())
98    }
99}
100
101pub trait PatchedArrayExt: PatchedArraySlotsExt {
102    #[inline]
103    fn n_lanes(&self) -> usize {
104        self.n_lanes
105    }
106
107    #[inline]
108    fn offset(&self) -> usize {
109        self.offset
110    }
111
112    #[inline]
113    fn lane_range(&self, chunk: usize, lane: usize) -> VortexResult<Range<usize>> {
114        assert!(chunk * 1024 <= self.as_ref().len() + self.offset());
115        assert!(lane < self.n_lanes());
116
117        let start = self.lane_offsets().execute_scalar(
118            chunk * self.n_lanes() + lane,
119            &mut LEGACY_SESSION.create_execution_ctx(),
120        )?;
121        let stop = self.lane_offsets().execute_scalar(
122            chunk * self.n_lanes() + lane + 1,
123            &mut LEGACY_SESSION.create_execution_ctx(),
124        )?;
125
126        let start = start
127            .as_primitive()
128            .as_::<usize>()
129            .ok_or_else(|| vortex_err!("could not cast lane_offset to usize"))?;
130
131        let stop = stop
132            .as_primitive()
133            .as_::<usize>()
134            .ok_or_else(|| vortex_err!("could not cast lane_offset to usize"))?;
135
136        Ok(start..stop)
137    }
138
139    fn slice_chunks(&self, chunks: Range<usize>) -> VortexResult<Array<Patched>> {
140        let lane_offsets_start = chunks.start * self.n_lanes();
141        let lane_offsets_stop = chunks.end * self.n_lanes() + 1;
142
143        let sliced_lane_offsets = self
144            .lane_offsets()
145            .slice(lane_offsets_start..lane_offsets_stop)?;
146        let indices = self.patch_indices().clone();
147        let values = self.patch_values().clone();
148
149        let begin = (chunks.start * 1024).saturating_sub(self.offset());
150        let end = (chunks.end * 1024)
151            .saturating_sub(self.offset())
152            .min(self.as_ref().len());
153
154        let offset = if chunks.start == 0 { self.offset() } else { 0 };
155        let inner = self.inner().slice(begin..end)?;
156        let len = inner.len();
157        let dtype = self.as_ref().dtype().clone();
158        let slots = PatchedSlots {
159            inner,
160            lane_offsets: sliced_lane_offsets,
161            patch_indices: indices,
162            patch_values: values,
163        }
164        .into_slots();
165
166        Ok(unsafe { Patched::new_unchecked(dtype, len, slots, self.n_lanes(), offset) })
167    }
168}
169
170impl<T: TypedArrayRef<Patched>> PatchedArrayExt for T {}
171
172impl Patched {
173    pub fn from_array_and_patches(
174        inner: ArrayRef,
175        patches: &Patches,
176        ctx: &mut ExecutionCtx,
177    ) -> VortexResult<Array<Patched>> {
178        vortex_ensure!(
179            inner.dtype().eq_with_nullability_superset(patches.dtype()),
180            "array DType must match patches DType"
181        );
182
183        vortex_ensure!(
184            inner.dtype().is_primitive(),
185            "Creating PatchedArray from Patches only supported for primitive arrays"
186        );
187
188        vortex_ensure!(
189            patches.num_patches() <= u32::MAX as usize,
190            "PatchedArray does not support > u32::MAX patch values"
191        );
192
193        vortex_ensure!(
194            patches.values().all_valid(ctx)?,
195            "PatchedArray cannot be built from Patches with nulls"
196        );
197
198        let values_ptype = patches.dtype().as_ptype();
199
200        let TransposedPatches {
201            n_lanes,
202            lane_offsets,
203            indices,
204            values,
205        } = transpose_patches(patches, ctx)?;
206
207        let lane_offsets = PrimitiveArray::from_buffer_handle(
208            BufferHandle::new_host(lane_offsets),
209            PType::U32,
210            Validity::NonNullable,
211        )
212        .into_array();
213        let indices = PrimitiveArray::from_buffer_handle(
214            BufferHandle::new_host(indices),
215            PType::U16,
216            Validity::NonNullable,
217        )
218        .into_array();
219        let values = PrimitiveArray::from_buffer_handle(
220            BufferHandle::new_host(values),
221            values_ptype,
222            Validity::NonNullable,
223        )
224        .into_array();
225
226        let dtype = inner.dtype().clone();
227        let len = inner.len();
228        let slots = PatchedSlots {
229            inner,
230            lane_offsets,
231            patch_indices: indices,
232            patch_values: values,
233        }
234        .into_slots();
235        Ok(unsafe { Self::new_unchecked(dtype, len, slots, n_lanes, 0) })
236    }
237
238    pub(crate) unsafe fn new_unchecked(
239        dtype: DType,
240        len: usize,
241        slots: ArraySlots,
242        n_lanes: usize,
243        offset: usize,
244    ) -> Array<Patched> {
245        unsafe {
246            Array::from_parts_unchecked(
247                ArrayParts::new(Patched, dtype, len, PatchedData { n_lanes, offset })
248                    .with_slots(slots),
249            )
250        }
251    }
252}
253
254/// Transpose a set of patches from the default sorted layout into the data parallel layout.
255fn transpose_patches(patches: &Patches, ctx: &mut ExecutionCtx) -> VortexResult<TransposedPatches> {
256    let array_len = patches.array_len();
257    let offset = patches.offset();
258
259    let indices = patches
260        .indices()
261        .clone()
262        .execute::<Canonical>(ctx)?
263        .into_primitive();
264
265    let values = patches
266        .values()
267        .clone()
268        .execute::<Canonical>(ctx)?
269        .into_primitive();
270
271    let indices_ptype = indices.ptype();
272    let values_ptype = values.ptype();
273
274    let indices = indices.buffer_handle().clone().unwrap_host();
275    let values = values.buffer_handle().clone().unwrap_host();
276
277    match_each_unsigned_integer_ptype!(indices_ptype, |I| {
278        match_each_native_ptype!(values_ptype, |V| {
279            let indices: Buffer<I> = Buffer::from_byte_buffer(indices);
280            let values: Buffer<V> = Buffer::from_byte_buffer(values);
281
282            Ok(transpose(
283                indices.as_slice(),
284                values.as_slice(),
285                offset,
286                array_len,
287            ))
288        })
289    })
290}
291
292#[expect(clippy::cast_possible_truncation)]
293fn transpose<I: IntegerPType, V: NativePType>(
294    indices_in: &[I],
295    values_in: &[V],
296    offset: usize,
297    array_len: usize,
298) -> TransposedPatches {
299    // Total number of slots is number of chunks times number of lanes.
300    let n_chunks = array_len.div_ceil(1024);
301    assert!(
302        n_chunks <= u32::MAX as usize,
303        "Cannot transpose patches for array with >= 4 trillion elements"
304    );
305
306    let n_lanes = patch_lanes::<V>();
307
308    // We know upfront how many indices and values we'll have.
309    let mut indices_buffer = BufferMut::with_capacity(indices_in.len());
310    let mut values_buffer = BufferMut::with_capacity(values_in.len());
311
312    // Number of patches in each chunk/lane.
313    let mut lane_offsets: BufferMut<u32> = BufferMut::zeroed(n_chunks * n_lanes + 1);
314
315    // Scan the index/value pairs once to get chunk/lane counts.
316    for index in indices_in {
317        let index = index.as_() - offset;
318        let chunk = index / 1024;
319        let lane = index % n_lanes;
320
321        lane_offsets[chunk * n_lanes + lane + 1] += 1;
322    }
323
324    for index in 1..lane_offsets.len() {
325        lane_offsets[index] += lane_offsets[index - 1];
326    }
327
328    // Loop over patches, writing them to final positions.
329    let indices_out = indices_buffer.spare_capacity_mut();
330    let values_out = values_buffer.spare_capacity_mut();
331    for (index, &value) in std::iter::zip(indices_in, values_in) {
332        let index = index.as_() - offset;
333        let chunk = index / 1024;
334        let lane = index % n_lanes;
335
336        let position = &mut lane_offsets[chunk * n_lanes + lane];
337        indices_out[*position as usize].write((index % 1024) as u16);
338        values_out[*position as usize].write(value);
339        *position += 1;
340    }
341
342    unsafe {
343        indices_buffer.set_len(indices_in.len());
344        values_buffer.set_len(values_in.len());
345    }
346
347    for index in indices_in {
348        let index = index.as_() - offset;
349        let chunk = index / 1024;
350        let lane = index % n_lanes;
351
352        lane_offsets[chunk * n_lanes + lane] -= 1;
353    }
354
355    TransposedPatches {
356        n_lanes,
357        lane_offsets: lane_offsets.freeze().into_byte_buffer(),
358        indices: indices_buffer.freeze().into_byte_buffer(),
359        values: values_buffer.freeze().into_byte_buffer(),
360    }
361}
362
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use super::PatchedSlots;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::array_slots;
    use crate::arrays::Null;
    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    /// Slot layout with one mandatory and one optional child, used to exercise the
    /// `array_slots` macro on `Option<ArrayRef>` fields.
    #[array_slots(Null)]
    struct OptionalPatchedSlots {
        required: ArrayRef,
        maybe: Option<ArrayRef>,
    }

    #[test]
    fn generated_slots_round_trip() {
        let first = PrimitiveArray::new(buffer![1u8, 2, 3], Validity::NonNullable).into_array();
        let second = PrimitiveArray::new(buffer![4u8, 5, 6], Validity::NonNullable).into_array();

        // Borrowed view over a two-slot vector.
        let slots = vec![Some(first.clone()), Some(second.clone())];
        let view = OptionalPatchedSlotsView::from_slots(&slots);
        assert_eq!(view.required.len(), 3);
        assert_eq!(view.maybe.expect("optional slot").len(), 3);

        // Owned reconstruction consumes the same slot vector.
        let owned = OptionalPatchedSlots::from_slots(slots.into());
        assert_eq!(owned.required.len(), first.len());
        assert_eq!(owned.maybe.expect("optional clone").len(), second.len());

        // The four-slot Patched layout: inner, lane_offsets, patch_indices, patch_values.
        let patched_slots = [&first, &second, &first, &second]
            .into_iter()
            .cloned()
            .map(Some)
            .collect::<Vec<_>>();
        let patched = PatchedSlots::from_slots(patched_slots.into());
        assert_eq!(patched.inner.len(), first.len());
        assert_eq!(patched.patch_values.len(), second.len());
    }
}