Skip to main content

vortex_array/arrays/patched/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::ops::Range;
7
8use vortex_buffer::Buffer;
9use vortex_buffer::BufferMut;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_err;
13
14use crate::ArrayRef;
15use crate::Canonical;
16use crate::ExecutionCtx;
17use crate::IntoArray;
18use crate::LEGACY_SESSION;
19use crate::VortexSessionExecute;
20use crate::array::Array;
21use crate::array::ArrayParts;
22use crate::array::TypedArrayRef;
23use crate::array_slots;
24use crate::arrays::Patched;
25use crate::arrays::PrimitiveArray;
26use crate::arrays::patched::TransposedPatches;
27use crate::arrays::patched::patch_lanes;
28use crate::buffer::BufferHandle;
29use crate::dtype::DType;
30use crate::dtype::IntegerPType;
31use crate::dtype::NativePType;
32use crate::dtype::PType;
33use crate::match_each_native_ptype;
34use crate::match_each_unsigned_integer_ptype;
35use crate::patches::Patches;
36use crate::validity::Validity;
37
/// Per-array metadata for a [`Patched`] array (everything that is not a child array).
#[derive(Debug, Clone)]
pub struct PatchedData {
    /// Number of lanes the patch indices and values have been split into. Each of the `n_chunks`
    /// chunks of 1024 values is split into `n_lanes` lanes; a patch at position `i` within its
    /// chunk is assigned to lane `i % n_lanes`, so each lane covers 1024 / n_lanes positions
    /// that might be patched.
    pub(super) n_lanes: usize,

    /// Offset into the first chunk at which this array's logical values begin.
    ///
    /// Patch indices of the first chunk that are less than `offset` lie before the start of the
    /// array and should be skipped; `offset` should be subtracted from the remaining indices to
    /// get their final position in the executed array.
    pub(super) offset: usize,
}
52
/// Child arrays of a [`Patched`] array. Fields map positionally onto the array's slots, in the
/// order declared here.
#[array_slots(Patched)]
pub struct PatchedSlots {
    /// The inner array containing the base unpatched values.
    pub inner: ArrayRef,
    /// Cumulative offsets into `patch_indices`/`patch_values`, one per (chunk, lane) pair plus a
    /// trailing end offset. The patches for lane `l` of chunk `c` occupy positions
    /// `lane_offsets[c * n_lanes + l]..lane_offsets[c * n_lanes + l + 1]`.
    pub lane_offsets: ArrayRef,
    /// The position of each patched (exception) value within its 1024-element chunk, grouped by
    /// (chunk, lane).
    pub patch_indices: ArrayRef,
    /// The patched (exception) values, in the same order as `patch_indices`.
    pub patch_values: ArrayRef,
}
64
65impl Display for PatchedData {
66    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
67        write!(f, "n_lanes: {}, offset: {}", self.n_lanes, self.offset)
68    }
69}
70
71impl PatchedData {
72    pub(crate) fn validate(
73        &self,
74        dtype: &DType,
75        len: usize,
76        slots: &PatchedSlotsView,
77    ) -> VortexResult<()> {
78        vortex_ensure!(
79            slots.inner.dtype() == dtype,
80            "PatchedArray base dtype {} does not match outer dtype {}",
81            slots.inner.dtype(),
82            dtype
83        );
84        vortex_ensure!(
85            slots.inner.len() == len,
86            "PatchedArray base len {} does not match outer len {}",
87            slots.inner.len(),
88            len
89        );
90        vortex_ensure!(
91            slots.patch_indices.len() == slots.patch_values.len(),
92            "PatchedArray patch indices len {} does not match patch values len {}",
93            slots.patch_indices.len(),
94            slots.patch_values.len()
95        );
96        Ok(())
97    }
98}
99
100pub trait PatchedArrayExt: PatchedArraySlotsExt {
101    #[inline]
102    fn n_lanes(&self) -> usize {
103        self.n_lanes
104    }
105
106    #[inline]
107    fn offset(&self) -> usize {
108        self.offset
109    }
110
111    #[inline]
112    fn lane_range(&self, chunk: usize, lane: usize) -> VortexResult<Range<usize>> {
113        assert!(chunk * 1024 <= self.as_ref().len() + self.offset());
114        assert!(lane < self.n_lanes());
115
116        let start = self.lane_offsets().execute_scalar(
117            chunk * self.n_lanes() + lane,
118            &mut LEGACY_SESSION.create_execution_ctx(),
119        )?;
120        let stop = self.lane_offsets().execute_scalar(
121            chunk * self.n_lanes() + lane + 1,
122            &mut LEGACY_SESSION.create_execution_ctx(),
123        )?;
124
125        let start = start
126            .as_primitive()
127            .as_::<usize>()
128            .ok_or_else(|| vortex_err!("could not cast lane_offset to usize"))?;
129
130        let stop = stop
131            .as_primitive()
132            .as_::<usize>()
133            .ok_or_else(|| vortex_err!("could not cast lane_offset to usize"))?;
134
135        Ok(start..stop)
136    }
137
138    fn slice_chunks(&self, chunks: Range<usize>) -> VortexResult<Array<Patched>> {
139        let lane_offsets_start = chunks.start * self.n_lanes();
140        let lane_offsets_stop = chunks.end * self.n_lanes() + 1;
141
142        let sliced_lane_offsets = self
143            .lane_offsets()
144            .slice(lane_offsets_start..lane_offsets_stop)?;
145        let indices = self.patch_indices().clone();
146        let values = self.patch_values().clone();
147
148        let begin = (chunks.start * 1024).saturating_sub(self.offset());
149        let end = (chunks.end * 1024)
150            .saturating_sub(self.offset())
151            .min(self.as_ref().len());
152
153        let offset = if chunks.start == 0 { self.offset() } else { 0 };
154        let inner = self.inner().slice(begin..end)?;
155        let len = inner.len();
156        let dtype = self.as_ref().dtype().clone();
157        let slots = PatchedSlots {
158            inner,
159            lane_offsets: sliced_lane_offsets,
160            patch_indices: indices,
161            patch_values: values,
162        }
163        .into_slots();
164
165        Ok(unsafe { Patched::new_unchecked(dtype, len, slots, self.n_lanes(), offset) })
166    }
167}
168
169impl<T: TypedArrayRef<Patched>> PatchedArrayExt for T {}
170
171impl Patched {
172    pub fn from_array_and_patches(
173        inner: ArrayRef,
174        patches: &Patches,
175        ctx: &mut ExecutionCtx,
176    ) -> VortexResult<Array<Patched>> {
177        vortex_ensure!(
178            inner.dtype().eq_with_nullability_superset(patches.dtype()),
179            "array DType must match patches DType"
180        );
181
182        vortex_ensure!(
183            inner.dtype().is_primitive(),
184            "Creating PatchedArray from Patches only supported for primitive arrays"
185        );
186
187        vortex_ensure!(
188            patches.num_patches() <= u32::MAX as usize,
189            "PatchedArray does not support > u32::MAX patch values"
190        );
191
192        vortex_ensure!(
193            patches.values().all_valid(ctx)?,
194            "PatchedArray cannot be built from Patches with nulls"
195        );
196
197        let values_ptype = patches.dtype().as_ptype();
198
199        let TransposedPatches {
200            n_lanes,
201            lane_offsets,
202            indices,
203            values,
204        } = transpose_patches(patches, ctx)?;
205
206        let lane_offsets = PrimitiveArray::from_buffer_handle(
207            BufferHandle::new_host(lane_offsets),
208            PType::U32,
209            Validity::NonNullable,
210        )
211        .into_array();
212        let indices = PrimitiveArray::from_buffer_handle(
213            BufferHandle::new_host(indices),
214            PType::U16,
215            Validity::NonNullable,
216        )
217        .into_array();
218        let values = PrimitiveArray::from_buffer_handle(
219            BufferHandle::new_host(values),
220            values_ptype,
221            Validity::NonNullable,
222        )
223        .into_array();
224
225        let dtype = inner.dtype().clone();
226        let len = inner.len();
227        let slots = PatchedSlots {
228            inner,
229            lane_offsets,
230            patch_indices: indices,
231            patch_values: values,
232        }
233        .into_slots();
234        Ok(unsafe { Self::new_unchecked(dtype, len, slots, n_lanes, 0) })
235    }
236
237    pub(crate) unsafe fn new_unchecked(
238        dtype: DType,
239        len: usize,
240        slots: Vec<Option<ArrayRef>>,
241        n_lanes: usize,
242        offset: usize,
243    ) -> Array<Patched> {
244        unsafe {
245            Array::from_parts_unchecked(
246                ArrayParts::new(Patched, dtype, len, PatchedData { n_lanes, offset })
247                    .with_slots(slots),
248            )
249        }
250    }
251}
252
253/// Transpose a set of patches from the default sorted layout into the data parallel layout.
254fn transpose_patches(patches: &Patches, ctx: &mut ExecutionCtx) -> VortexResult<TransposedPatches> {
255    let array_len = patches.array_len();
256    let offset = patches.offset();
257
258    let indices = patches
259        .indices()
260        .clone()
261        .execute::<Canonical>(ctx)?
262        .into_primitive();
263
264    let values = patches
265        .values()
266        .clone()
267        .execute::<Canonical>(ctx)?
268        .into_primitive();
269
270    let indices_ptype = indices.ptype();
271    let values_ptype = values.ptype();
272
273    let indices = indices.buffer_handle().clone().unwrap_host();
274    let values = values.buffer_handle().clone().unwrap_host();
275
276    match_each_unsigned_integer_ptype!(indices_ptype, |I| {
277        match_each_native_ptype!(values_ptype, |V| {
278            let indices: Buffer<I> = Buffer::from_byte_buffer(indices);
279            let values: Buffer<V> = Buffer::from_byte_buffer(values);
280
281            Ok(transpose(
282                indices.as_slice(),
283                values.as_slice(),
284                offset,
285                array_len,
286            ))
287        })
288    })
289}
290
291#[expect(clippy::cast_possible_truncation)]
292fn transpose<I: IntegerPType, V: NativePType>(
293    indices_in: &[I],
294    values_in: &[V],
295    offset: usize,
296    array_len: usize,
297) -> TransposedPatches {
298    // Total number of slots is number of chunks times number of lanes.
299    let n_chunks = array_len.div_ceil(1024);
300    assert!(
301        n_chunks <= u32::MAX as usize,
302        "Cannot transpose patches for array with >= 4 trillion elements"
303    );
304
305    let n_lanes = patch_lanes::<V>();
306
307    // We know upfront how many indices and values we'll have.
308    let mut indices_buffer = BufferMut::with_capacity(indices_in.len());
309    let mut values_buffer = BufferMut::with_capacity(values_in.len());
310
311    // Number of patches in each chunk/lane.
312    let mut lane_offsets: BufferMut<u32> = BufferMut::zeroed(n_chunks * n_lanes + 1);
313
314    // Scan the index/value pairs once to get chunk/lane counts.
315    for index in indices_in {
316        let index = index.as_() - offset;
317        let chunk = index / 1024;
318        let lane = index % n_lanes;
319
320        lane_offsets[chunk * n_lanes + lane + 1] += 1;
321    }
322
323    for index in 1..lane_offsets.len() {
324        lane_offsets[index] += lane_offsets[index - 1];
325    }
326
327    // Loop over patches, writing them to final positions.
328    let indices_out = indices_buffer.spare_capacity_mut();
329    let values_out = values_buffer.spare_capacity_mut();
330    for (index, &value) in std::iter::zip(indices_in, values_in) {
331        let index = index.as_() - offset;
332        let chunk = index / 1024;
333        let lane = index % n_lanes;
334
335        let position = &mut lane_offsets[chunk * n_lanes + lane];
336        indices_out[*position as usize].write((index % 1024) as u16);
337        values_out[*position as usize].write(value);
338        *position += 1;
339    }
340
341    unsafe {
342        indices_buffer.set_len(indices_in.len());
343        values_buffer.set_len(values_in.len());
344    }
345
346    for index in indices_in {
347        let index = index.as_() - offset;
348        let chunk = index / 1024;
349        let lane = index % n_lanes;
350
351        lane_offsets[chunk * n_lanes + lane] -= 1;
352    }
353
354    TransposedPatches {
355        n_lanes,
356        lane_offsets: lane_offsets.freeze().into_byte_buffer(),
357        indices: indices_buffer.freeze().into_byte_buffer(),
358        values: values_buffer.freeze().into_byte_buffer(),
359    }
360}
361
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use super::PatchedSlots;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::array_slots;
    use crate::arrays::Null;
    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    /// Test-only slot struct with one required and one optional slot, used to exercise the code
    /// generated by `#[array_slots]`.
    #[array_slots(Null)]
    struct OptionalPatchedSlots {
        required: ArrayRef,
        maybe: Option<ArrayRef>,
    }

    #[test]
    fn generated_slots_round_trip() {
        let first = PrimitiveArray::new(buffer![1u8, 2, 3], Validity::NonNullable).into_array();
        let second = PrimitiveArray::new(buffer![4u8, 5, 6], Validity::NonNullable).into_array();

        let raw_slots = vec![Some(first.clone()), Some(second.clone())];

        // The borrowed view exposes both slots.
        let borrowed = OptionalPatchedSlotsView::from_slots(&raw_slots);
        assert_eq!(borrowed.required.len(), 3);
        assert_eq!(borrowed.maybe.expect("optional slot").len(), 3);

        // The owned conversion consumes the slot vector.
        let owned = OptionalPatchedSlots::from_slots(raw_slots);
        assert_eq!(owned.required.len(), first.len());
        assert_eq!(owned.maybe.expect("optional clone").len(), second.len());

        // PatchedSlots maps its four positional slots onto named fields.
        let patched = PatchedSlots::from_slots(vec![
            Some(first.clone()),
            Some(second.clone()),
            Some(first.clone()),
            Some(second.clone()),
        ]);
        assert_eq!(patched.inner.len(), first.len());
        assert_eq!(patched.patch_values.len(), second.len());
    }
}